diff --git a/.gitignore b/.gitignore index 8a71b76..e69de29 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +0,0 @@ -_examples \ No newline at end of file diff --git a/LICENSE b/LICENSE index 672a19d..46a15a0 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -kinetra.de/net License +github.com/deneonet/knet License Version: 1.0.0 Copyright (c) 2024 deneonet diff --git a/README.md b/README.md index 3fbbdde..bfb97df 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,33 @@ -# kinetra.de/net +# github.com/deneonet/knet A library to handle secure TCP connections over a custom protocol. -**Warning:** This library is not complete and contains unfinished code. +**Warning:** This library is not complete and may contain unfinished code. ## Security -**kinetra.de/net** uses AES-256 for data encryption, ECDH P521 for key exchange, and ECDSA P521 for signing. +**github.com/deneonet/knet** uses AES-256 for data encryption, ECDH P521 for key exchange, and ECDSA P521 for signing. ## Data Encoding -To efficiently encode data, **kinetra.de/net** utilizes [benc](https://github.com/deneonet/benc) as its serializer. +To efficiently encode data, **github.com/deneonet/knet** utilizes [benc](https://github.com/deneonet/benc) as its serializer. ## Key Rotations -If the private key of the server's certificate is compromised or just expired, simply generate a new one and update the client's root key as well. **cosair.gg** encodes a version field into the certificate and root key to verify that the client is always in sync with the server. If the versions do not match, a clear error will be returned. +If the private key of the server's certificate is compromised or just expired, simply generate a new one and update the client's root key as well. **kNet** encodes a version field into the certificate and root key to verify that the client is always in sync with the server. If the versions do not match, a clear error will be returned. ## Generating Certificates -As simple as `go run kinetra.de/net/gen -v {VERSION_NUMBER}`, everything is done locally on your machine. To ensure that the root key is in sync with the certificate, it will be generated as well. +As simple as `go run github.com/deneonet/knet/gen -v {VERSION_NUMBER}`, everything is done locally on your machine. To ensure that the root key is in sync with the certificate, it will be generated as well. ## The Handshake Process -1. **[Client]**: I want your certificate to prove your identity as **[Server]**. -2. **[Server]**: Sure, here’s my certificate. -3. **[Client]**: I'll check the signature using my root key, verifying that the public key was not compromised, is not expired, and matches the expected version. -4. **[Client]**: I verified it; it's valid. Here’s my public key. I'll create a shared secret using your public key. -5. **[Server]**: I have the shared secret now too. Let me verify that we have the same and that your public key was not compromised as well. I’ll send you an encrypted message using the shared secret. +1. **[Client]**: I want your certificate to prove your identity as **[Server]**. +2. **[Server]**: Sure, here’s my certificate. +3. **[Client]**: I'll check the signature using my root key, verifying that the public key was not compromised, is not expired, and matches the expected version. +4. **[Client]**: I verified it; it's valid. Here’s my public key. I'll create a shared secret using your public key. +5. **[Server]**: I have the shared secret now too. Let me verify that we have the same and that your public key was not compromised as well. I’ll send you an encrypted message using the shared secret. 6. **[Client]**: I successfully decrypted the message. Our connection is now secure! -## Examples +## Real-time chat Example -Find examples [here](https://github.com/deneonet/cosair.gg-net-examples). +Find a real-time chat example [here](https://github.com/deneonet/knet-real-time-chat). diff --git a/cert/cert.benc.go b/cert/cert.benc.go deleted file mode 100644 index 8ed247c..0000000 --- a/cert/cert.benc.go +++ /dev/null @@ -1,184 +0,0 @@ -// Code generated by bencgen golang. DO NOT EDIT. -// source: ../schemas/Certificate.benc - -package cert - -import ( - "github.com/deneonet/benc/std" - "github.com/deneonet/benc/impl/gen" -) - -// Struct - Certificate -type Certificate struct { - PrivateKey []byte - PublicKey []byte - PublicKeySignature []byte - Version int - CreatedAt int64 -} - -// Reserved Ids - Certificate -var certificateRIds = []uint16{} - -// Size - Certificate -func (certificate *Certificate) Size() int { - return certificate.size(0) -} - -// Nested Size - Certificate -func (certificate *Certificate) size(id uint16) (s int) { - s += bstd.SizeBytes(certificate.PrivateKey) + 2 - s += bstd.SizeBytes(certificate.PublicKey) + 2 - s += bstd.SizeBytes(certificate.PublicKeySignature) + 2 - s += bstd.SizeInt(certificate.Version) + 2 - s += bstd.SizeInt64() + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - Certificate -func (certificate *Certificate) SizePlain() (s int) { - s += bstd.SizeBytes(certificate.PrivateKey) - s += bstd.SizeBytes(certificate.PublicKey) - s += bstd.SizeBytes(certificate.PublicKeySignature) - s += bstd.SizeInt(certificate.Version) - s += bstd.SizeInt64() - return -} - -// Marshal - Certificate -func (certificate *Certificate) Marshal(b []byte) { - certificate.marshal(0, b, 0) -} - -// Nested Marshal - Certificate -func (certificate *Certificate) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) - n = bstd.MarshalBytes(n, b, certificate.PrivateKey) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 2) - n = bstd.MarshalBytes(n, b, certificate.PublicKey) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 3) - n = bstd.MarshalBytes(n, b, certificate.PublicKeySignature) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Varint, 4) - n = bstd.MarshalInt(n, b, certificate.Version) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed64, 5) - n = bstd.MarshalInt64(n, b, certificate.CreatedAt) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - Certificate -func (certificate *Certificate) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalBytes(n, b, certificate.PrivateKey) - n = bstd.MarshalBytes(n, b, certificate.PublicKey) - n = bstd.MarshalBytes(n, b, certificate.PublicKeySignature) - n = bstd.MarshalInt(n, b, certificate.Version) - n = bstd.MarshalInt64(n, b, certificate.CreatedAt) - return n -} - -// Unmarshal - Certificate -func (certificate *Certificate) Unmarshal(b []byte) (err error) { - _, err = certificate.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - Certificate -func (certificate *Certificate) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.PrivateKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 2); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 3); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.PublicKeySignature, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 4); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.Version, err = bstd.UnmarshalInt(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 5); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - Certificate -func (certificate *Certificate) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, certificate.PrivateKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - if n, certificate.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - if n, certificate.PublicKeySignature, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - if n, certificate.Version, err = bstd.UnmarshalInt(n, b); err != nil { - return - } - if n, certificate.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { - return - } - return -} - diff --git a/cert/root.benc.go b/cert/root.benc.go deleted file mode 100644 index 4d78b9b..0000000 --- a/cert/root.benc.go +++ /dev/null @@ -1,144 +0,0 @@ -// Code generated by bencgen golang. DO NOT EDIT. -// source: ../schemas/RootKey.benc - -package cert - -import ( - "github.com/deneonet/benc/std" - "github.com/deneonet/benc/impl/gen" -) - -// Struct - RootKey -type RootKey struct { - PublicKey []byte - Version int - CreatedAt int64 -} - -// Reserved Ids - RootKey -var rootKeyRIds = []uint16{} - -// Size - RootKey -func (rootKey *RootKey) Size() int { - return rootKey.size(0) -} - -// Nested Size - RootKey -func (rootKey *RootKey) size(id uint16) (s int) { - s += bstd.SizeBytes(rootKey.PublicKey) + 2 - s += bstd.SizeInt(rootKey.Version) + 2 - s += bstd.SizeInt64() + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - RootKey -func (rootKey *RootKey) SizePlain() (s int) { - s += bstd.SizeBytes(rootKey.PublicKey) - s += bstd.SizeInt(rootKey.Version) - s += bstd.SizeInt64() - return -} - -// Marshal - RootKey -func (rootKey *RootKey) Marshal(b []byte) { - rootKey.marshal(0, b, 0) -} - -// Nested Marshal - RootKey -func (rootKey *RootKey) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) - n = bstd.MarshalBytes(n, b, rootKey.PublicKey) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Varint, 2) - n = bstd.MarshalInt(n, b, rootKey.Version) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed64, 3) - n = bstd.MarshalInt64(n, b, rootKey.CreatedAt) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - RootKey -func (rootKey *RootKey) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalBytes(n, b, rootKey.PublicKey) - n = bstd.MarshalInt(n, b, rootKey.Version) - n = bstd.MarshalInt64(n, b, rootKey.CreatedAt) - return n -} - -// Unmarshal - RootKey -func (rootKey *RootKey) Unmarshal(b []byte) (err error) { - _, err = rootKey.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - RootKey -func (rootKey *RootKey) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, rootKeyRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, rootKey.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, rootKeyRIds, 2); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, rootKey.Version, err = bstd.UnmarshalInt(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, rootKeyRIds, 3); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, rootKey.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - RootKey -func (rootKey *RootKey) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, rootKey.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - if n, rootKey.Version, err = bstd.UnmarshalInt(n, b); err != nil { - return - } - if n, rootKey.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { - return - } - return -} - diff --git a/cert/rootkey.go b/cert/rootkey.go new file mode 100644 index 0000000..4c4ef04 --- /dev/null +++ b/cert/rootkey.go @@ -0,0 +1,146 @@ +// Code generated by bencgen go. DO NOT EDIT. +// source: client_root_key.benc + +package cert + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" + + +) + +// Struct - ClientRootKey +type ClientRootKey struct { + Key []byte + Version int + CreatedAt int64 +} + +// Reserved Ids - ClientRootKey +var clientRootKeyRIds = []uint16{} + +// Size - ClientRootKey +func (clientRootKey *ClientRootKey) Size() int { + return clientRootKey.NestedSize(0) +} + +// Nested Size - ClientRootKey +func (clientRootKey *ClientRootKey) NestedSize(id uint16) (s int) { + s += bstd.SizeBytes(clientRootKey.Key) + 2 + s += bstd.SizeInt(clientRootKey.Version) + 2 + s += bstd.SizeInt64() + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - ClientRootKey +func (clientRootKey *ClientRootKey) SizePlain() (s int) { + s += bstd.SizeBytes(clientRootKey.Key) + s += bstd.SizeInt(clientRootKey.Version) + s += bstd.SizeInt64() + return +} + +// Marshal - ClientRootKey +func (clientRootKey *ClientRootKey) Marshal(b []byte) { + clientRootKey.NestedMarshal(0, b, 0) +} + +// Nested Marshal - ClientRootKey +func (clientRootKey *ClientRootKey) NestedMarshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalBytes(n, b, clientRootKey.Key) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Varint, 2) + n = bstd.MarshalInt(n, b, clientRootKey.Version) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed64, 3) + n = bstd.MarshalInt64(n, b, clientRootKey.CreatedAt) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - ClientRootKey +func (clientRootKey *ClientRootKey) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalBytes(n, b, clientRootKey.Key) + n = bstd.MarshalInt(n, b, clientRootKey.Version) + n = bstd.MarshalInt64(n, b, clientRootKey.CreatedAt) + return n +} + +// Unmarshal - ClientRootKey +func (clientRootKey *ClientRootKey) Unmarshal(b []byte) (err error) { + _, err = clientRootKey.NestedUnmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - ClientRootKey +func (clientRootKey *ClientRootKey) NestedUnmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, clientRootKeyRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, clientRootKey.Key, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, clientRootKeyRIds, 2); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, clientRootKey.Version, err = bstd.UnmarshalInt(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, clientRootKeyRIds, 3); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, clientRootKey.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - ClientRootKey +func (clientRootKey *ClientRootKey) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, clientRootKey.Key, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, clientRootKey.Version, err = bstd.UnmarshalInt(n, b); err != nil { + return + } + if n, clientRootKey.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { + return + } + return +} + diff --git a/cert/servercert.go b/cert/servercert.go new file mode 100644 index 0000000..f0bb109 --- /dev/null +++ b/cert/servercert.go @@ -0,0 +1,186 @@ +// Code generated by bencgen go. DO NOT EDIT. +// source: server_certificate.benc + +package cert + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" + + +) + +// Struct - ServerCertificate +type ServerCertificate struct { + PublicKey []byte + PrivateKey []byte + PublicKeySignature []byte + Version int + CreatedAt int64 +} + +// Reserved Ids - ServerCertificate +var serverCertificateRIds = []uint16{} + +// Size - ServerCertificate +func (serverCertificate *ServerCertificate) Size() int { + return serverCertificate.NestedSize(0) +} + +// Nested Size - ServerCertificate +func (serverCertificate *ServerCertificate) NestedSize(id uint16) (s int) { + s += bstd.SizeBytes(serverCertificate.PublicKey) + 2 + s += bstd.SizeBytes(serverCertificate.PrivateKey) + 2 + s += bstd.SizeBytes(serverCertificate.PublicKeySignature) + 2 + s += bstd.SizeInt(serverCertificate.Version) + 2 + s += bstd.SizeInt64() + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - ServerCertificate +func (serverCertificate *ServerCertificate) SizePlain() (s int) { + s += bstd.SizeBytes(serverCertificate.PublicKey) + s += bstd.SizeBytes(serverCertificate.PrivateKey) + s += bstd.SizeBytes(serverCertificate.PublicKeySignature) + s += bstd.SizeInt(serverCertificate.Version) + s += bstd.SizeInt64() + return +} + +// Marshal - ServerCertificate +func (serverCertificate *ServerCertificate) Marshal(b []byte) { + serverCertificate.NestedMarshal(0, b, 0) +} + +// Nested Marshal - ServerCertificate +func (serverCertificate *ServerCertificate) NestedMarshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalBytes(n, b, serverCertificate.PublicKey) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 2) + n = bstd.MarshalBytes(n, b, serverCertificate.PrivateKey) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 3) + n = bstd.MarshalBytes(n, b, serverCertificate.PublicKeySignature) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Varint, 4) + n = bstd.MarshalInt(n, b, serverCertificate.Version) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed64, 5) + n = bstd.MarshalInt64(n, b, serverCertificate.CreatedAt) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - ServerCertificate +func (serverCertificate *ServerCertificate) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalBytes(n, b, serverCertificate.PublicKey) + n = bstd.MarshalBytes(n, b, serverCertificate.PrivateKey) + n = bstd.MarshalBytes(n, b, serverCertificate.PublicKeySignature) + n = bstd.MarshalInt(n, b, serverCertificate.Version) + n = bstd.MarshalInt64(n, b, serverCertificate.CreatedAt) + return n +} + +// Unmarshal - ServerCertificate +func (serverCertificate *ServerCertificate) Unmarshal(b []byte) (err error) { + _, err = serverCertificate.NestedUnmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - ServerCertificate +func (serverCertificate *ServerCertificate) NestedUnmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 2); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.PrivateKey, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 3); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.PublicKeySignature, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 4); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.Version, err = bstd.UnmarshalInt(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 5); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - ServerCertificate +func (serverCertificate *ServerCertificate) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, serverCertificate.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, serverCertificate.PrivateKey, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, serverCertificate.PublicKeySignature, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, serverCertificate.Version, err = bstd.UnmarshalInt(n, b); err != nil { + return + } + if n, serverCertificate.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { + return + } + return +} + diff --git a/cert/utils.go b/cert/utils.go index 8423d97..39a33ba 100644 --- a/cert/utils.go +++ b/cert/utils.go @@ -1,5 +1,3 @@ -//go:generate bencgen --in ../schemas/RootKey.benc --out . --file root.benc --lang go -//go:generate bencgen --in ../schemas/Certificate.benc --out . --file cert.benc --lang go package cert import ( @@ -13,35 +11,46 @@ import ( ) var ( - ErrFailedVerification = errors.New("public key of certificate couldn't be verified") - ErrCertificateExpired = errors.New("certificate has expired") - ErrRootKeyExpired = errors.New("root key has expired") + ErrClientRootKeyExpired = errors.New("root key has expired") + ErrServerCertificateExpired = errors.New("certificate has expired") + ErrVersionMismatch = errors.New("certificate and root key are not in sync") + ErrFailedVerification = errors.New("public key of certificate couldn't be verified") - ErrVersionMismatch = errors.New("certificate and root key are not in sync") + ErrCertificateSigningFailed = errors.New("failed to sign certificate") - oneYear = 365 * 24 * time.Hour // one year in hours + ErrECDSARootKeyGenerationFailed = errors.New("failed to generate ECDSA root key") + ErrECDHPrivateKeyGenerationFailed = errors.New("failed to generate ECDH private key") + ErrServerPublicKeyCreationFailed = errors.New("failed to create server's public key") + + // TODO: Make expiration time configurable + DefaultCertificateExpiry = 365 * 24 * time.Hour // Default one year in hours ) -func VerifyCertificate(cert Certificate, root RootKey) (*ecdh.PublicKey, error) { +// VerifyCertificate verifies the certificate with the provided root key, returning the public key. +func VerifyCertificate(cert ServerCertificate, root ClientRootKey, expiry time.Duration) (*ecdh.PublicKey, error) { if cert.Version != root.Version { return nil, ErrVersionMismatch } - if time.Now().Unix()-cert.CreatedAt > int64(oneYear.Seconds()) { - return nil, ErrCertificateExpired + if expiry == 0 { + expiry = DefaultCertificateExpiry + } + + if time.Now().Unix()-cert.CreatedAt > int64(expiry.Seconds()) { + return nil, ErrServerCertificateExpired } - if time.Now().Unix()-root.CreatedAt > int64(oneYear.Seconds()) { - return nil, ErrRootKeyExpired + if time.Now().Unix()-root.CreatedAt > int64(expiry.Seconds()) { + return nil, ErrClientRootKeyExpired } publicKey, err := ecdh.P521().NewPublicKey(cert.PublicKey) if err != nil { - return nil, err + return nil, ErrServerPublicKeyCreationFailed } rootKey := new(ecdsa.PublicKey) rootKey.Curve = elliptic.P521() - rootKey.X, rootKey.Y = elliptic.UnmarshalCompressed(elliptic.P521(), root.PublicKey) + rootKey.X, rootKey.Y = elliptic.UnmarshalCompressed(elliptic.P521(), root.Key) if !ecdsa.VerifyASN1(rootKey, publicKey.Bytes(), cert.PublicKeySignature) { return nil, ErrFailedVerification @@ -50,35 +59,34 @@ func VerifyCertificate(cert Certificate, root RootKey) (*ecdh.PublicKey, error) return publicKey, nil } +// GenerateCertificateChain generates a certificate chain and stores it in the specified file paths. func GenerateCertificateChain(version int, certFilePath string, rootKeyFilePath string) error { privateKey, err := ecdh.P521().GenerateKey(rand.Reader) if err != nil { - return err + return ErrECDHPrivateKeyGenerationFailed } rootKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { - return err + return ErrECDSARootKeyGenerationFailed } publicKey := privateKey.PublicKey().Bytes() signature, err := ecdsa.SignASN1(rand.Reader, rootKey, publicKey) if err != nil { - return err + return ErrCertificateSigningFailed } - cert := Certificate{ - PrivateKey: privateKey.Bytes(), + cert := ServerCertificate{ PublicKey: publicKey, PublicKeySignature: signature, - - Version: version, - CreatedAt: time.Now().Unix(), + PrivateKey: privateKey.Bytes(), + Version: version, + CreatedAt: time.Now().Unix(), } - root := RootKey{ - PublicKey: elliptic.MarshalCompressed(elliptic.P521(), rootKey.PublicKey.X, rootKey.PublicKey.Y), - + root := ClientRootKey{ + Key: elliptic.MarshalCompressed(elliptic.P521(), rootKey.PublicKey.X, rootKey.PublicKey.Y), Version: version, CreatedAt: time.Now().Unix(), } @@ -86,26 +94,32 @@ func GenerateCertificateChain(version int, certFilePath string, rootKeyFilePath certData := make([]byte, cert.Size()) cert.Marshal(certData) - certFile, err := os.Create(certFilePath) - if err != nil { + if err := writeToFile(certFilePath, certData); err != nil { return err } - defer certFile.Close() - if _, err = certFile.Write(certData); err != nil { + rootKeyData := make([]byte, root.Size()) + root.Marshal(rootKeyData) + + if err := writeToFile(rootKeyFilePath, rootKeyData); err != nil { return err } - rootKeyData := make([]byte, root.Size()) - root.Marshal(rootKeyData) + return nil +} - rootKeyFile, err := os.Create(rootKeyFilePath) +func writeToFile(filePath string, data []byte) error { + file, err := os.Create(filePath) if err != nil { return err } - defer rootKeyFile.Close() + defer file.Close() + + if _, err := file.Write(data); err != nil { + return err + } - if _, err = rootKeyFile.Write(rootKeyData); err != nil { + if err := os.Chmod(filePath, 0600); err != nil { return err } diff --git a/client.go b/client.go index 02070ad..730066d 100644 --- a/client.go +++ b/client.go @@ -10,10 +10,10 @@ import ( "sync" "time" - "kinetra.de/net/cert" - "kinetra.de/net/crypto" - "kinetra.de/net/netutils" - "kinetra.de/net/packets" + "github.com/deneonet/knet/cert" + "github.com/deneonet/knet/crypto" + "github.com/deneonet/knet/handshake" + "github.com/deneonet/knet/netutils" bstd "github.com/deneonet/benc/std" ) @@ -22,11 +22,13 @@ type Client struct { priv *ecdh.PrivateKey session ClientSession - rootKey cert.RootKey + rootKey cert.ClientRootKey RootKeyFile string mutex *sync.Mutex + netUtilsSettings netutils.NetUtilsSettings + ReadDeadline time.Duration WriteDeadline time.Duration @@ -53,43 +55,26 @@ const ( ClientHandshakeError ) -func (c *Client) setDeadline(conn net.Conn, handshakeComplete bool) { - conn.SetDeadline(time.Time{}) - - var readDeadline, writeDeadline time.Duration - if handshakeComplete { - readDeadline, writeDeadline = c.ReadDeadline, c.WriteDeadline - } else { - readDeadline, writeDeadline = c.HandshakeReadDeadline, c.HandshakeWriteDeadline - } - - if readDeadline > 0 { - conn.SetReadDeadline(time.Now().Add(readDeadline)) - } - if writeDeadline > 0 { - conn.SetWriteDeadline(time.Now().Add(writeDeadline)) - } -} - func (c *Client) processHandshake(conn net.Conn, buf []byte) (ClientHandshakeResult, error) { - packet, err := packets.UnmarshalHandshakePacket(buf) + packet, err := handshake.UnmarshalHandshakePacket(buf) if err != nil { return ClientHandshakeError, err } - switch packets.HandshakePacketId(packet.Id) { - case packets.CertificateResponse: - var certificate cert.Certificate + switch packet.Type { + case handshake.PacketTypeCertificateResponse: + var certificate cert.ServerCertificate if err := certificate.Unmarshal(packet.Payload); err != nil { return ClientHandshakeError, err } - serverPublicKey, err := cert.VerifyCertificate(certificate, c.rootKey) + // TODO: Custom expiry + serverPublicKey, err := cert.VerifyCertificate(certificate, c.rootKey, 0) if err != nil { return ClientHandshakeError, err } - if err = packets.SendHandshakePacket(conn, packets.ClientInformation, c.priv.PublicKey().Bytes()); err != nil { + if err = handshake.SendHandshakePacket(conn, handshake.PacketTypeClientInformation, c.priv.PublicKey().Bytes(), &c.netUtilsSettings); err != nil { return ClientHandshakeError, err } @@ -99,15 +84,17 @@ func (c *Client) processHandshake(conn net.Conn, buf []byte) (ClientHandshakeRes } aesSecret := sha256.Sum256(sharedSecret) + c.mutex.Lock() c.session = ClientSession{ Conn: conn, SharedSecret: aesSecret[:], } + c.mutex.Unlock() return ClientContinueHandshake, nil - case packets.ServerVerification: + case handshake.PacketTypeServerVerification: if _, err := crypto.Decrypt(c.session.SharedSecret, packet.Payload); err != nil { - return ClientHandshakeError, ErrDecryptingVerificationId + return ClientHandshakeError, ErrDecryptingServerVerification } return ClientHandshakeComplete, nil @@ -117,11 +104,11 @@ func (c *Client) processHandshake(conn net.Conn, buf []byte) (ClientHandshakeRes } func (c *Client) Connect(address string) (net.Conn, error) { - if c.HandshakeReadDeadline == 0 { - c.HandshakeReadDeadline = 500 * time.Millisecond - } - if c.HandshakeWriteDeadline == 0 { - c.HandshakeWriteDeadline = 500 * time.Millisecond + c.netUtilsSettings = netutils.NetUtilsSettings{ + HandshakeReadDeadline: c.HandshakeReadDeadline, + HandshakeWriteDeadline: c.HandshakeWriteDeadline, + ReadDeadline: c.ReadDeadline, + WriteDeadline: c.WriteDeadline, } c.mutex = &sync.Mutex{} @@ -143,11 +130,9 @@ func (c *Client) Connect(address string) (net.Conn, error) { return nil, err } - defer func() { - conn.Close() - }() + defer conn.Close() - if err = packets.SendHandshakePacket(conn, packets.CertificateRequest, nil); err != nil { + if err = handshake.SendHandshakePacket(conn, handshake.PacketTypeCertificateRequest, nil, &c.netUtilsSettings); err != nil { return nil, err } @@ -155,8 +140,7 @@ func (c *Client) Connect(address string) (net.Conn, error) { buf := make([]byte, c.BufferSize) for { - c.setDeadline(conn, handshakeComplete) - s, err := netutils.ReadFromConn(conn, buf) + s, err := netutils.ReadFromConn(conn, buf, &c.netUtilsSettings) if err != nil && !handshakeComplete { return nil, err @@ -166,6 +150,7 @@ func (c *Client) Connect(address string) (net.Conn, error) { if !handshakeComplete { result, err = c.processHandshake(conn, buf[4:s]) if err != nil { + conn.Close() return nil, err } @@ -177,7 +162,8 @@ func (c *Client) Connect(address string) (net.Conn, error) { handshakeComplete = true s = 0 - c.setDeadline(conn, handshakeComplete) + c.netUtilsSettings.HandshakeCompleted = true + if c.OnSecureConnect != nil { if action := c.OnSecureConnect(c.session); action == Close { break @@ -216,7 +202,7 @@ func (c *Client) Connect(address string) (net.Conn, error) { } } - return nil, err + return conn, nil } func (c *Client) Send(b []byte) error { @@ -224,11 +210,17 @@ func (c *Client) Send(b []byte) error { if err != nil { return err } - return netutils.SendToConn(c.session.Conn, len(encrypted)+1, func(n int, b []byte) { b[n] = 1; copy(b[n+1:], encrypted) }) + return netutils.SendToConn(c.session.Conn, len(encrypted)+1, func(n int, b []byte) { + b[n] = 1 + copy(b[n+1:], encrypted) + }, &c.netUtilsSettings) } func (c *Client) SendUnsecure(b []byte) error { - return netutils.SendToConn(c.session.Conn, len(b)+1, func(n int, buf []byte) { buf[n] = 0; copy(buf[n+1:], b) }) + return netutils.SendToConn(c.session.Conn, len(b)+1, func(n int, buf []byte) { + buf[n] = 0 + copy(buf[n+1:], b) + }, &c.netUtilsSettings) } func (c *Client) SendPacket(id int, p Packet) error { diff --git a/common.go b/common.go index 1a4a140..e23561c 100644 --- a/common.go +++ b/common.go @@ -23,9 +23,9 @@ const ( ) var ( - ErrSessionNotFound = errors.New("no session with that remote address found") - ErrDecryptingVerificationId = errors.New("error decrypting verification id") - ErrInvalidHandshakePacket = errors.New("invalid handshake packet received") - ErrBufTooSmall = errors.New("buffer is too small for the requested size") - ErrInvalidRootKey = errors.New("invalid root key in client struct") + ErrSessionNotFound = errors.New("no session associated with that connection found") + ErrInvalidRootKey = errors.New("invalid root key in client struct") + ErrInvalidHandshakePacket = errors.New("invalid handshake packet received") + ErrDecryptingServerVerification = errors.New("error decrypting server verification") + ErrDataExceededBufferSize = errors.New("received data size exceeded buffer size") ) diff --git a/crypto/encryption.go b/crypto/encryption.go index 2cfa0d5..d2fa7c9 100644 --- a/crypto/encryption.go +++ b/crypto/encryption.go @@ -12,6 +12,8 @@ var ( ErrCipherTextTooShort = errors.New("cipher text length is smaller than nonce length") ) +// Encrypt encrypts the provided data using the given AES key and returns the ciphertext +// which includes the nonce as the first part of the result. func Encrypt(key, data []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { @@ -32,6 +34,8 @@ func Encrypt(key, data []byte) ([]byte, error) { return cipherText, nil } +// Decrypt decrypts the provided ciphertext using the given AES key and returns the decrypted data. +// The ciphertext must include the nonce at the start. func Decrypt(key, cipherData []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { @@ -49,6 +53,5 @@ func Decrypt(key, cipherData []byte) ([]byte, error) { } nonce, cipherText := cipherData[:nonceSize], cipherData[nonceSize:] - plainData, err := aesGCM.Open(nil, nonce, cipherText, nil) - return plainData, err + return aesGCM.Open(nil, nonce, cipherText, nil) } diff --git a/gen/main.go b/gen/main.go index ba67089..c2820b1 100644 --- a/gen/main.go +++ b/gen/main.go @@ -4,16 +4,16 @@ import ( "flag" "fmt" - "kinetra.de/net/cert" + "github.com/deneonet/knet/cert" ) func main() { - version := flag.Int("v", 0x01, "Certificate And Root Key version") + version := flag.Int("v", 0x01, "Server certificate and client root key version.") flag.Parse() fmt.Printf("Generating for version: %d\n", *version) - if err := cert.GenerateCertificateChain(*version, "s.cert", "root.key"); err != nil { + if err := cert.GenerateCertificateChain(*version, "server.kc", "client.kr"); err != nil { panic(err) } } diff --git a/go.mod b/go.mod index 91266a3..4daf140 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,9 @@ -module kinetra.de/net +module github.com/deneonet/knet go 1.23.2 +require github.com/deneonet/benc v1.1.6 + require ( - github.com/deneonet/benc v1.1.2 golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect ) diff --git a/go.sum b/go.sum index bddf9cd..d0c225f 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,10 @@ github.com/deneonet/benc v1.1.2 h1:JNJSnA53zVLjt4Bz1HwxG4tQg475LP+kd8rgUuV4tc4= github.com/deneonet/benc v1.1.2/go.mod h1:HbL4lzHT0jkmlYa36bZw0a0Nhj4NsXG7bd/bXRxJYy4= +github.com/deneonet/benc v1.1.4 h1:h88Ghu8TL3vaVhnAtsJv+nd66Fr6t0foeeZuxs7U+5Y= +github.com/deneonet/benc v1.1.4/go.mod h1:L61vicZVEHmk7l4FsBaI6S3AhLR1viG7sGsnw0PgGXA= +github.com/deneonet/benc v1.1.6 h1:+Uo+/8ABnV7SH7JsGdnHwD8zgSfIjDyKIz0CxWxNdlQ= +github.com/deneonet/benc v1.1.6/go.mod h1:UCfkM5Od0B2huwv/ZItvtUb7QnALFt9YXtX8NXX4Lts= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= diff --git a/handshake/packet.go b/handshake/packet.go new file mode 100644 index 0000000..6e59cf5 --- /dev/null +++ b/handshake/packet.go @@ -0,0 +1,135 @@ +// Code generated by bencgen go. DO NOT EDIT. +// source: handshake_packet.benc + +package handshake + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" + + +) + +// Enum - PacketType +type PacketType int +const ( + PacketTypeCertificateRequest PacketType = iota + PacketTypeCertificateResponse + PacketTypeServerVerification + PacketTypeClientInformation +) + +// Struct - Packet +type Packet struct { + Payload []byte + Type PacketType +} + +// Reserved Ids - Packet +var packetRIds = []uint16{} + +// Size - Packet +func (packet *Packet) Size() int { + return packet.NestedSize(0) +} + +// Nested Size - Packet +func (packet *Packet) NestedSize(id uint16) (s int) { + s += bstd.SizeBytes(packet.Payload) + 2 + s += bgenimpl.SizeEnum(packet.Type) + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - Packet +func (packet *Packet) SizePlain() (s int) { + s += bstd.SizeBytes(packet.Payload) + s += bgenimpl.SizeEnum(packet.Type) + return +} + +// Marshal - Packet +func (packet *Packet) Marshal(b []byte) { + packet.NestedMarshal(0, b, 0) +} + +// Nested Marshal - Packet +func (packet *Packet) NestedMarshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalBytes(n, b, packet.Payload) + n = bgenimpl.MarshalTag(n, b, bgenimpl.ArrayMap, 2) + n = bgenimpl.MarshalEnum(n, b, packet.Type) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - Packet +func (packet *Packet) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalBytes(n, b, packet.Payload) + n = bgenimpl.MarshalEnum(n, b, packet.Type) + return n +} + +// Unmarshal - Packet +func (packet *Packet) Unmarshal(b []byte) (err error) { + _, err = packet.NestedUnmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - Packet +func (packet *Packet) NestedUnmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, packetRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, packet.Payload, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, packetRIds, 2); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, packet.Type, err = bgenimpl.UnmarshalEnum[PacketType](n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - Packet +func (packet *Packet) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, packet.Payload, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, packet.Type, err = bgenimpl.UnmarshalEnum[PacketType](n, b); err != nil { + return + } + return +} + diff --git a/handshake/utils.go b/handshake/utils.go new file mode 100644 index 0000000..e0465a3 --- /dev/null +++ b/handshake/utils.go @@ -0,0 +1,35 @@ +package handshake + +import ( + "net" + + "github.com/deneonet/knet/cert" + "github.com/deneonet/knet/netutils" +) + +// SendHandshakePacket sends a handshake packet with the given type and payload. +func SendHandshakePacket(conn net.Conn, typ PacketType, payload []byte, netUtilsSettings *netutils.NetUtilsSettings) error { + packet := Packet{payload, typ} + s := packet.Size() + return netutils.SendToConn(conn, s, func(n int, b []byte) { + packet.NestedMarshal(n, b, 0) + }, netUtilsSettings) +} + +// SendCertResponseHandshakePacket sends a certificate response handshake packet +// to the connection, ensuring that the certificate's private key is cleared before transmission. +func SendCertResponseHandshakePacket(conn net.Conn, cert cert.ServerCertificate, netUtilsSettings *netutils.NetUtilsSettings) error { + cert.PrivateKey = nil // scary! + b := make([]byte, cert.Size()) + cert.Marshal(b) + return SendHandshakePacket(conn, PacketTypeCertificateResponse, b, netUtilsSettings) +} + +// UnmarshalHandshakePacket unmarshals the given buffer into a Handshake packet. +func UnmarshalHandshakePacket(buf []byte) (packet Packet, err error) { + err = packet.Unmarshal(buf) + if err != nil { + return packet, err + } + return packet, nil +} diff --git a/netutils/conn.go b/netutils/conn.go index eea8fca..0cd288c 100644 --- a/netutils/conn.go +++ b/netutils/conn.go @@ -1,21 +1,27 @@ package netutils import ( - "errors" "io" "net" + "time" "github.com/deneonet/benc" bstd "github.com/deneonet/benc/std" ) -var ( - ErrBufTooSmall = errors.New("buffer is too small for the requested size") -) +type NetUtilsSettings struct { + HandshakeReadDeadline time.Duration + HandshakeWriteDeadline time.Duration + + ReadDeadline time.Duration + WriteDeadline time.Duration + + HandshakeCompleted bool +} func readFull(r io.Reader, s int, buf []byte) (n int, err error) { if len(buf) < s { - return 0, ErrBufTooSmall + return 0, benc.ErrBufTooSmall } for n < s && err == nil { @@ -34,7 +40,35 @@ func readFull(r io.Reader, s int, buf []byte) (n int, err error) { return n, err } -func ReadFromConn(conn net.Conn, buf []byte) (s uint32, err error) { +func setReadDeadline(conn net.Conn, settings *NetUtilsSettings) { + conn.SetReadDeadline(time.Time{}) + + readDeadline := settings.ReadDeadline + if !settings.HandshakeCompleted { + readDeadline = settings.HandshakeReadDeadline + } + + if readDeadline > 0 { + conn.SetReadDeadline(time.Now().Add(readDeadline)) + } +} + +func setWriteDeadline(conn net.Conn, settings *NetUtilsSettings) { + conn.SetWriteDeadline(time.Time{}) + + writeDeadline := settings.WriteDeadline + if !settings.HandshakeCompleted { + writeDeadline = settings.HandshakeWriteDeadline + } + + if writeDeadline > 0 { + conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + } +} + +func ReadFromConn(conn net.Conn, buf []byte, settings *NetUtilsSettings) (s uint32, err error) { + setReadDeadline(conn, settings) + if _, err = readFull(conn, 4, buf); err != nil { return } @@ -52,10 +86,12 @@ func ReadFromConn(conn net.Conn, buf []byte) (s uint32, err error) { return } -// TODO: buffer size +// Buffer pool for efficient memory management, TODO: Custom buffer size var bufPool = benc.NewBufPool(benc.WithBufferSize(4092 * 2 * 2 * 2 * 2)) -func SendToConn(conn net.Conn, s int, f func(n int, b []byte)) (err error) { +func SendToConn(conn net.Conn, s int, f func(n int, b []byte), settings *NetUtilsSettings) (err error) { + setWriteDeadline(conn, settings) + fs := s + bstd.SizeUint32() _, errT := bufPool.Marshal(fs, func(b []byte) (n int) { @@ -64,6 +100,7 @@ func SendToConn(conn net.Conn, s int, f func(n int, b []byte)) (err error) { _, err = conn.Write(b) return }) + if err != nil { return } diff --git a/packets/handshake.benc.go b/packets/handshake.benc.go deleted file mode 100644 index 7b68979..0000000 --- a/packets/handshake.benc.go +++ /dev/null @@ -1,124 +0,0 @@ -// Code generated by bencgen golang. DO NOT EDIT. -// source: ../schemas/Handshake.benc - -package packets - -import ( - "github.com/deneonet/benc/std" - "github.com/deneonet/benc/impl/gen" -) - -// Struct - HandshakePacket -type HandshakePacket struct { - Id byte - Payload []byte -} - -// Reserved Ids - HandshakePacket -var handshakePacketRIds = []uint16{} - -// Size - HandshakePacket -func (handshakePacket *HandshakePacket) Size() int { - return handshakePacket.size(0) -} - -// Nested Size - HandshakePacket -func (handshakePacket *HandshakePacket) size(id uint16) (s int) { - s += bstd.SizeByte() + 2 - s += bstd.SizeBytes(handshakePacket.Payload) + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - HandshakePacket -func (handshakePacket *HandshakePacket) SizePlain() (s int) { - s += bstd.SizeByte() - s += bstd.SizeBytes(handshakePacket.Payload) - return -} - -// Marshal - HandshakePacket -func (handshakePacket *HandshakePacket) Marshal(b []byte) { - handshakePacket.marshal(0, b, 0) -} - -// Nested Marshal - HandshakePacket -func (handshakePacket *HandshakePacket) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed8, 1) - n = bstd.MarshalByte(n, b, handshakePacket.Id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 2) - n = bstd.MarshalBytes(n, b, handshakePacket.Payload) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - HandshakePacket -func (handshakePacket *HandshakePacket) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalByte(n, b, handshakePacket.Id) - n = bstd.MarshalBytes(n, b, handshakePacket.Payload) - return n -} - -// Unmarshal - HandshakePacket -func (handshakePacket *HandshakePacket) Unmarshal(b []byte) (err error) { - _, err = handshakePacket.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - HandshakePacket -func (handshakePacket *HandshakePacket) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, handshakePacketRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, handshakePacket.Id, err = bstd.UnmarshalByte(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, handshakePacketRIds, 2); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, handshakePacket.Payload, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - HandshakePacket -func (handshakePacket *HandshakePacket) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, handshakePacket.Id, err = bstd.UnmarshalByte(n, b); err != nil { - return - } - if n, handshakePacket.Payload, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - return -} - diff --git a/packets/handshake.go b/packets/handshake.go deleted file mode 100644 index afc5189..0000000 --- a/packets/handshake.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:generate bencgen --in ../schemas/Handshake.benc --out . --file handshake.benc --lang go -package packets - -import ( - "net" - - "kinetra.de/net/cert" - "kinetra.de/net/netutils" -) - -type HandshakePacketId byte - -const ( - CertificateRequest HandshakePacketId = iota - CertificateResponse - - ServerVerification - ClientInformation -) - -func SendHandshakePacket(conn net.Conn, id HandshakePacketId, payload []byte) error { - packet := HandshakePacket{byte(id), payload} - s := packet.Size() - return netutils.SendToConn(conn, s, func(n int, b []byte) { packet.marshal(n, b, 0) }) -} - -func SendCertResponseHandshakePacket(conn net.Conn, id HandshakePacketId, cert cert.Certificate) error { - cert.PrivateKey = nil // scary! - - b := make([]byte, cert.Size()) - cert.Marshal(b) - - packet := HandshakePacket{byte(id), b} - s := packet.Size() - return netutils.SendToConn(conn, s, func(n int, b []byte) { packet.marshal(n, b, 0) }) -} - -func UnmarshalHandshakePacket(buf []byte) (packet HandshakePacket, err error) { - err = packet.Unmarshal(buf) - return -} diff --git a/schemas/Certificate.benc b/schemas/Certificate.benc deleted file mode 100644 index 6aeb8a9..0000000 --- a/schemas/Certificate.benc +++ /dev/null @@ -1,13 +0,0 @@ -header cert; - -ctr Certificate { - bytes privateKey = 1; - bytes publicKey = 2; - bytes publicKeySignature = 3; - - int version = 4; - int64 createdAt = 5; -} - -# DO NOT EDIT. -# [meta_s] eyJtc2dzIjp7IkNlcnRpZmljYXRlIjp7InJJZHMiOm51bGwsImZpZWxkcyI6eyIxIjp7IklkIjoxLCJOYW1lIjoicHJpdmF0ZUtleSIsIlR5cGUiOnsiVG9rZW5UeXBlIjoxNCwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX0sIjIiOnsiSWQiOjIsIk5hbWUiOiJwdWJsaWNLZXkiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTQsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIzIjp7IklkIjozLCJOYW1lIjoicHVibGljS2V5U2lnbmF0dXJlIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE0LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiNCI6eyJJZCI6NCwiTmFtZSI6InZlcnNpb24iLCJUeXBlIjp7IlRva2VuVHlwZSI6OSwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX0sIjUiOnsiSWQiOjUsIk5hbWUiOiJjcmVhdGVkQXQiLCJUeXBlIjp7IlRva2VuVHlwZSI6NiwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX19fX19 [meta_e] \ No newline at end of file diff --git a/schemas/Handshake.benc b/schemas/Handshake.benc deleted file mode 100644 index 675b0f6..0000000 --- a/schemas/Handshake.benc +++ /dev/null @@ -1,9 +0,0 @@ -header packets; - -ctr HandshakePacket { - byte id = 1; - bytes payload = 2; -} - -# DO NOT EDIT. -# [meta_s] eyJtc2dzIjp7IkhhbmRzaGFrZVBhY2tldCI6eyJySWRzIjpudWxsLCJmaWVsZHMiOnsiMSI6eyJJZCI6MSwiTmFtZSI6ImlkIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE5LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiMiI6eyJJZCI6MiwiTmFtZSI6InBheWxvYWQiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTQsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19fX19fQ== [meta_e] \ No newline at end of file diff --git a/schemas/RootKey.benc b/schemas/RootKey.benc deleted file mode 100644 index f0ef4e3..0000000 --- a/schemas/RootKey.benc +++ /dev/null @@ -1,12 +0,0 @@ -header cert; - -ctr RootKey { - bytes publicKey = 1; - - int version = 2; - int64 createdAt = 3; -} - - -# DO NOT EDIT. -# [meta_s] eyJtc2dzIjp7IlJvb3RLZXkiOnsicklkcyI6bnVsbCwiZmllbGRzIjp7IjEiOnsiSWQiOjEsIk5hbWUiOiJwdWJsaWNLZXkiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTQsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIyIjp7IklkIjoyLCJOYW1lIjoidmVyc2lvbiIsIlR5cGUiOnsiVG9rZW5UeXBlIjo5LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiMyI6eyJJZCI6MywiTmFtZSI6ImNyZWF0ZWRBdCIsIlR5cGUiOnsiVG9rZW5UeXBlIjo2LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fX19fX0= [meta_e] \ No newline at end of file diff --git a/schemas/client_root_key.benc b/schemas/client_root_key.benc new file mode 100644 index 0000000..b6f2524 --- /dev/null +++ b/schemas/client_root_key.benc @@ -0,0 +1,13 @@ +define cert; + +var go_package = "github.com/deneonet/knet/cert"; + +ctr ClientRootKey { + bytes key = 1; + + int version = 2; + int64 createdAt = 3; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IkNsaWVudFJvb3RLZXkiOnsicklkcyI6bnVsbCwiZmllbGRzIjp7IjEiOnsiaWQiOjEsIk5hbWUiOiJrZXkiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTksIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIyIjp7ImlkIjoyLCJOYW1lIjoidmVyc2lvbiIsIlR5cGUiOnsiVG9rZW5UeXBlIjoxNCwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJjdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX0sIjMiOnsiaWQiOjMsIk5hbWUiOiJjcmVhdGVkQXQiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTEsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19fX19fQ== [meta_e] \ No newline at end of file diff --git a/schemas/gen.bat b/schemas/gen.bat new file mode 100644 index 0000000..1feb354 --- /dev/null +++ b/schemas/gen.bat @@ -0,0 +1,3 @@ +bencgen --in client_root_key.benc --out ../cert --file rootkey --lang go +bencgen --in handshake_packet.benc --out ../handshake --file packet --lang go +bencgen --in server_certificate.benc --out ../cert --file servercert --lang go \ No newline at end of file diff --git a/schemas/handshake_packet.benc b/schemas/handshake_packet.benc new file mode 100644 index 0000000..a8da8d1 --- /dev/null +++ b/schemas/handshake_packet.benc @@ -0,0 +1,18 @@ +define handshake; + +var go_package = "github.com/deneonet/knet/handshake"; + +enum PacketType { + CertificateRequest, + CertificateResponse, + ServerVerification, + ClientInformation +} + +ctr Packet { + bytes payload = 1; + PacketType type = 2; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IlBhY2tldCI6eyJySWRzIjpudWxsLCJmaWVsZHMiOnsiMSI6eyJpZCI6MSwiTmFtZSI6InBheWxvYWQiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTksIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIyIjp7ImlkIjoyLCJOYW1lIjoidHlwZSIsIlR5cGUiOnsiVG9rZW5UeXBlIjowLCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsImN0ck5hbWUiOiJQYWNrZXRUeXBlIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX19fX19 [meta_e] \ No newline at end of file diff --git a/schemas/server_certificate.benc b/schemas/server_certificate.benc new file mode 100644 index 0000000..f386bb3 --- /dev/null +++ b/schemas/server_certificate.benc @@ -0,0 +1,16 @@ +define cert; + +var go_package = "github.com/deneonet/knet/cert"; + +ctr ServerCertificate { + bytes publicKey = 1; + bytes privateKey = 2; + + bytes publicKeySignature = 3; + + int version = 4; + int64 createdAt = 5; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IlNlcnZlckNlcnRpZmljYXRlIjp7InJJZHMiOm51bGwsImZpZWxkcyI6eyIxIjp7ImlkIjoxLCJOYW1lIjoicHVibGljS2V5IiwiVHlwZSI6eyJUb2tlblR5cGUiOjE5LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsImN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiMiI6eyJpZCI6MiwiTmFtZSI6InByaXZhdGVLZXkiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTksIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIzIjp7ImlkIjozLCJOYW1lIjoicHVibGljS2V5U2lnbmF0dXJlIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE5LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsImN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiNCI6eyJpZCI6NCwiTmFtZSI6InZlcnNpb24iLCJUeXBlIjp7IlRva2VuVHlwZSI6MTQsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCI1Ijp7ImlkIjo1LCJOYW1lIjoiY3JlYXRlZEF0IiwiVHlwZSI6eyJUb2tlblR5cGUiOjExLCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsImN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fX19fX0= [meta_e] \ No newline at end of file diff --git a/server.go b/server.go index 47e8c17..13b1f01 100644 --- a/server.go +++ b/server.go @@ -3,55 +3,49 @@ package knet import ( "crypto/ecdh" "crypto/sha256" - "fmt" "io" "net" "os" + "slices" "sync" "time" - "kinetra.de/net/cert" - "kinetra.de/net/crypto" - "kinetra.de/net/netutils" - "kinetra.de/net/packets" + "github.com/deneonet/knet/cert" + "github.com/deneonet/knet/crypto" + "github.com/deneonet/knet/handshake" + "github.com/deneonet/knet/netutils" bstd "github.com/deneonet/benc/std" ) type Server struct { - // Addr is the address the server will bind to. - Addr string - - // Cert is the file path to the server's certificate. - CertFile string - cert cert.Certificate - priv *ecdh.PrivateKey - - sessions map[net.Conn]ServerSession - - mutex *sync.Mutex - - ReadDeadline time.Duration - WriteDeadline time.Duration - + Addr string + CertFile string + cert cert.ServerCertificate + priv *ecdh.PrivateKey + sessions map[net.Conn]ServerSession + mutex sync.RWMutex + netUtilsSettings netutils.NetUtilsSettings + ReadDeadline time.Duration + WriteDeadline time.Duration HandshakeReadDeadline time.Duration HandshakeWriteDeadline time.Duration - - EnableConnPurge bool - ConnPurgeInterval time.Duration - - IdleTimeout time.Duration - BufferSize int - - OnRead func(net.Conn, ReadInfo) AfterAction - OnDisconnect func(net.Conn) - OnSecureConnect func(net.Conn, ServerSession) AfterAction + EnableConnPurge bool + ConnPurgeInterval time.Duration + MinSessionsBeforePurge int + IdleTimeout time.Duration + BufferSize int + OnRead func(net.Conn, ReadInfo) AfterAction + OnDisconnect func(net.Conn) + OnSecureConnect func(net.Conn, ServerSession) AfterAction + OnAcceptingError func(error) bool + OnConnectionError func(net.Conn, error) AfterAction } type ServerSession struct { SharedSecret []byte - // TODO: LastActivity time.Time - // TODO: Data map[string]interface{} + LastActivity time.Time + Data map[string]interface{} } type ServerHandshakeResult byte @@ -62,44 +56,46 @@ const ( ServerHandshakeError ) -/*TOOD: func (s *Server) connectionPurge() { +type connectionErrorResult byte + +const ( + connectionErrorContinue connectionErrorResult = iota + connectionErrorReturn + connectionErrorMoveOn +) + +func (s *Server) connectionPurge() { for { time.Sleep(s.ConnPurgeInterval) - s.mutex.Lock() - if proto.IsDebugMode() { - fmt.Println("Start connection purge") + if len(s.sessions) < s.MinSessionsBeforePurge { + continue } - currentTime := time.Now() - for conn, session := range s.ServerSessions { - if currentTime.Sub(session.LastActivity) > s.IdleTimeout { - // Connection has been idle for too long, close it. + + s.mutex.Lock() + + now := time.Now() + for conn, session := range s.sessions { + if now.Sub(session.LastActivity) > s.IdleTimeout { conn.Close() - delete(s.ServerSessions, conn) + delete(s.sessions, conn) } } - if proto.IsDebugMode() { - fmt.Println("Finished connection purge") - } + s.mutex.Unlock() } -}*/ - -func (s *Server) setDeadline(conn net.Conn, handshakeComplete bool) { - conn.SetDeadline(time.Time{}) +} - var readDeadline, writeDeadline time.Duration - if handshakeComplete { - readDeadline, writeDeadline = s.ReadDeadline, s.WriteDeadline - } else { - readDeadline, writeDeadline = s.HandshakeReadDeadline, s.HandshakeWriteDeadline +func (s *Server) handleConnectionError(conn net.Conn, err error) connectionErrorResult { + if err == nil { + return connectionErrorMoveOn } - if readDeadline > 0 { - conn.SetReadDeadline(time.Now().Add(readDeadline)) - } - if writeDeadline > 0 { - conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + action := s.OnConnectionError(conn, err) + if action == Close { + return connectionErrorReturn } + + return connectionErrorContinue } func (s *Server) handleConnection(conn net.Conn) { @@ -110,29 +106,32 @@ func (s *Server) handleConnection(conn net.Conn) { handshakeComplete := false buf := make([]byte, s.BufferSize) - var session ServerSession + for { - s.setDeadline(conn, handshakeComplete) - size, err := netutils.ReadFromConn(conn, buf) + size, err := netutils.ReadFromConn(conn, buf, &s.netUtilsSettings) + if int(size) > len(buf) { - panic("data exceeded buffer size") - break + if result := s.handleConnectionError(conn, ErrDataExceededBufferSize); result == connectionErrorReturn { + return + } + continue } if err != nil && !handshakeComplete { - //TODO: error handling - panic(err) - break + if result := s.handleConnectionError(conn, err); result == connectionErrorReturn { + return + } + continue } - var result ServerHandshakeResult if !handshakeComplete { - result, err = s.processHandshake(conn, buf[4:size]) + result, err := s.processHandshake(conn, buf[4:size]) if err != nil { - //TODO: error handling - panic(err) - break + if result := s.handleConnectionError(conn, err); result == connectionErrorReturn { + return + } + continue } if result == ServerContinueHandshake { @@ -141,15 +140,12 @@ func (s *Server) handleConnection(conn net.Conn) { if result == ServerHandshakeComplete { handshakeComplete = true - size = 0 - session, _ = s.GetSession(conn) - s.setDeadline(conn, handshakeComplete) - if s.OnSecureConnect != nil { - if action := s.OnSecureConnect(conn, session); action == Close { - break - } + s.netUtilsSettings.HandshakeCompleted = true + + if s.OnSecureConnect != nil && s.OnSecureConnect(conn, session) == Close { + break } continue @@ -157,13 +153,9 @@ func (s *Server) handleConnection(conn net.Conn) { } if err == io.EOF || size == 0 { - conn.Close() - s.RemoveSession(conn) - if s.OnDisconnect != nil { s.OnDisconnect(conn) } - return } @@ -171,10 +163,13 @@ func (s *Server) handleConnection(conn net.Conn) { continue } - encrypted := buf[4] == 1 - data := buf[5:size] - if encrypted && err == nil { - data, err = crypto.Decrypt(session.SharedSecret, data) + var data []byte = nil + if err == nil { + encrypted := buf[4] == 1 + data = buf[5:size] + if encrypted { + data, err = crypto.Decrypt(session.SharedSecret, data) + } } action := s.OnRead(conn, ReadInfo{ @@ -189,19 +184,18 @@ func (s *Server) handleConnection(conn net.Conn) { } func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeResult, error) { - packet, err := packets.UnmarshalHandshakePacket(buf) + packet, err := handshake.UnmarshalHandshakePacket(buf) if err != nil { return ServerHandshakeError, err } - switch packets.HandshakePacketId(packet.Id) { - case packets.CertificateRequest: - if err = packets.SendCertResponseHandshakePacket(conn, packets.CertificateResponse, s.cert); err != nil { + switch packet.Type { + case handshake.PacketTypeCertificateRequest: + if err = handshake.SendCertResponseHandshakePacket(conn, s.cert, &s.netUtilsSettings); err != nil { return ServerHandshakeError, err } - return ServerContinueHandshake, nil - case packets.ClientInformation: + case handshake.PacketTypeClientInformation: clientPublicKey, err := ecdh.P521().NewPublicKey(packet.Payload) if err != nil { return ServerHandshakeError, err @@ -215,7 +209,8 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes aesSecret := sha256.Sum256(sharedSecret) s.SetSession(conn, ServerSession{ SharedSecret: aesSecret[:], - // TODO: Data: make(map[string]interface{}), + LastActivity: time.Now(), + Data: make(map[string]interface{}), }) verification, err := crypto.Encrypt(aesSecret[:], []byte{1, 2, 3, 4}) @@ -223,7 +218,9 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes return ServerHandshakeError, err } - packets.SendHandshakePacket(conn, packets.ServerVerification, verification) + if err := handshake.SendHandshakePacket(conn, handshake.PacketTypeServerVerification, verification, &s.netUtilsSettings); err != nil { + return ServerHandshakeError, err + } return ServerHandshakeComplete, nil } @@ -237,14 +234,15 @@ func (s *Server) Run() error { } defer listener.Close() - //TODO: if s.EnableConnPurge { - // //go s.connectionPurge() - //} - bytes, err := os.ReadFile(s.CertFile) + if s.EnableConnPurge { + go s.connectionPurge() + } + + certData, err := os.ReadFile(s.CertFile) if err != nil { return err } - if err = s.cert.Unmarshal(bytes); err != nil { + if err = s.cert.Unmarshal(certData); err != nil { return err } @@ -255,23 +253,35 @@ func (s *Server) Run() error { if s.IdleTimeout == 0 { s.IdleTimeout = 30 * time.Minute } + if s.ConnPurgeInterval == 0 { + s.ConnPurgeInterval = 35 * time.Minute + } + if s.sessions == nil { s.sessions = make(map[net.Conn]ServerSession) } - if s.HandshakeReadDeadline == 0 { - s.HandshakeReadDeadline = 500 * time.Millisecond - } - if s.HandshakeWriteDeadline == 0 { - s.HandshakeWriteDeadline = 500 * time.Millisecond + s.netUtilsSettings = netutils.NetUtilsSettings{ + HandshakeReadDeadline: s.HandshakeReadDeadline, + HandshakeWriteDeadline: s.HandshakeWriteDeadline, + ReadDeadline: s.ReadDeadline, + WriteDeadline: s.WriteDeadline, } - s.mutex = &sync.Mutex{} + if s.OnAcceptingError == nil { + s.OnAcceptingError = func(err error) bool { return true } + } + if s.OnConnectionError == nil { + s.OnConnectionError = func(conn net.Conn, err error) AfterAction { panic(conn.RemoteAddr().String() + ": " + err.Error()) } + } for { conn, err := listener.Accept() if err != nil { - fmt.Println("Error accepting connection:", err) + if s.OnAcceptingError(err) { + return err + } + continue } go s.handleConnection(conn) @@ -288,52 +298,52 @@ func (s *Server) Send(conn net.Conn, b []byte) error { if err != nil { return err } - return netutils.SendToConn(conn, len(encrypted)+1, func(n int, b []byte) { b[n] = 1; copy(b[n+1:], encrypted) }) + return netutils.SendToConn(conn, len(encrypted)+1, func(n int, buf []byte) { buf[n] = 1; copy(buf[n+1:], encrypted) }, &s.netUtilsSettings) } func (s *Server) SendUnsecure(conn net.Conn, b []byte) error { - return netutils.SendToConn(conn, len(b)+1, func(n int, buf []byte) { buf[n] = 0; copy(buf[n+1:], b) }) + return netutils.SendToConn(conn, len(b)+1, func(n int, buf []byte) { buf[n] = 0; copy(buf[n+1:], b) }, &s.netUtilsSettings) } -func (s *Server) UnmarshalPacket(b []byte, f func(int, []byte) error) error { - n, id, err := bstd.UnmarshalInt(0, b) +func (s *Server) UnmarshalPacket(buf []byte, f func(int, []byte) error) error { + n, id, err := bstd.UnmarshalInt(0, buf) if err != nil { return err } - return f(id, b[n:]) + return f(id, buf[n:]) } func (s *Server) SendPacket(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.Send(conn, b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.Send(conn, buf) } func (s *Server) SendPacketUnsecure(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendUnsecure(conn, b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendUnsecure(conn, buf) } func (s *Server) SendPacketToAll(id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendToAll(b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendToAll(buf) } func (s *Server) SendPacketUnsecureToAll(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendUnsecureToAll(b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendUnsecureToAll(buf) } -func (s *Server) SendToAll(b []byte) error { +func (s *Server) SendToAll(buf []byte) error { for conn := range s.sessions { - if err := s.Send(conn, b); err != nil { + if err := s.Send(conn, buf); err != nil { return err } } @@ -341,9 +351,9 @@ func (s *Server) SendToAll(b []byte) error { return nil } -func (s *Server) SendUnsecureToAll(b []byte) error { +func (s *Server) SendUnsecureToAll(buf []byte) error { for conn := range s.sessions { - if err := s.SendUnsecure(conn, b); err != nil { + if err := s.SendUnsecure(conn, buf); err != nil { return err } } @@ -351,37 +361,30 @@ func (s *Server) SendUnsecureToAll(b []byte) error { return nil } -/*TODO: -func (s *Server) Store(conn gnet.Conn, key string, value interface{}) error { - ses, ok := s.GetSession(conn) - if !ok { - return ErrSessionNotFound +func (s *Server) SendToAllExcept(buf []byte, conns ...net.Conn) error { + for conn := range s.sessions { + if slices.Contains(conns, conn) { + continue + } + if err := s.Send(conn, buf); err != nil { + return err + } } - s.mutex.Lock() - ses.Data[key] = value - s.mutex.Unlock() - return nil } -func (s *Server) Get(conn gnet.Conn, key string) (interface{}, error) { - ses, ok := s.GetSession(conn) - if !ok { - return nil, ErrSessionNotFound - } - - s.mutex.Lock() - value := ses.Data[key] - s.mutex.Unlock() - - return value, nil -}*/ +func (s *Server) SendPacketToAllExcept(id int, p Packet, conns ...net.Conn) error { + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendToAllExcept(buf, conns...) +} func (s *Server) GetSession(conn net.Conn) (ServerSession, bool) { - s.mutex.Lock() + s.mutex.RLock() + defer s.mutex.RUnlock() ses, ok := s.sessions[conn] - s.mutex.Unlock() return ses, ok } @@ -396,3 +399,25 @@ func (s *Server) RemoveSession(conn net.Conn) { delete(s.sessions, conn) s.mutex.Unlock() } + +func (s *Server) Get(conn net.Conn, key string) (interface{}, error) { + s.mutex.RLock() + defer s.mutex.RUnlock() + ses, ok := s.sessions[conn] + if !ok { + return nil, ErrSessionNotFound + } + return ses.Data[key], nil +} + +func (s *Server) Store(conn net.Conn, key string, value interface{}) error { + s.mutex.Lock() + defer s.mutex.Unlock() + ses, ok := s.sessions[conn] + if !ok { + return ErrSessionNotFound + } + ses.Data[key] = value + s.sessions[conn] = ses + return nil +}