diff --git a/pkg/apis/deployment/v1alpha/duration.go b/pkg/apis/deployment/v1alpha/duration.go new file mode 100644 index 000000000..0f68f18bc --- /dev/null +++ b/pkg/apis/deployment/v1alpha/duration.go @@ -0,0 +1,81 @@ +// +// DISCLAIMER +// +// Copyright 2018 ArangoDB GmbH, Cologne, Germany +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copyright holder is ArangoDB GmbH, Cologne, Germany +// +// Author Ewout Prangsma +// + +package v1alpha + +import ( + "time" + + "github.com/pkg/errors" +) + +// Duration is a period of time, specified in go time.Duration format. +// This is intended to allow human friendly TTL's to be specified. +type Duration string + +// Validate the duration. +// Return errors when validation fails, nil on success. +func (d Duration) Validate() error { + if d != "" { + if _, err := time.ParseDuration(string(d)); err != nil { + return maskAny(errors.Wrapf(ValidationError, "Invalid duration: '%s': %s", string(d), err.Error())) + } + } + return nil +} + +// AsDuration parses the duration to a time.Duration value. +// In case of a parse error, 0 is returned. +func (d Duration) AsDuration() time.Duration { + if d == "" { + return 0 + } + result, err := time.ParseDuration(string(d)) + if err != nil { + return 0 + } + return result +} + +// NewDuration returns a reference to a Duration with given value. +func NewDuration(input Duration) *Duration { + return &input +} + +// NewDurationOrNil returns nil if input is nil, otherwise returns a clone of the given value. +func NewDurationOrNil(input *Duration) *Duration { + if input == nil { + return nil + } + return NewDuration(*input) +} + +// DurationOrDefault returns the default value (or empty string) if input is nil, otherwise returns the referenced value. +func DurationOrDefault(input *Duration, defaultValue ...Duration) Duration { + if input == nil { + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + return *input +} diff --git a/pkg/apis/deployment/v1alpha/tls_spec.go b/pkg/apis/deployment/v1alpha/tls_spec.go index 4298f4c60..710c187db 100644 --- a/pkg/apis/deployment/v1alpha/tls_spec.go +++ b/pkg/apis/deployment/v1alpha/tls_spec.go @@ -25,7 +25,6 @@ package v1alpha import ( "fmt" "net" - "time" "github.com/arangodb/kube-arangodb/pkg/util" "github.com/arangodb/kube-arangodb/pkg/util/k8sutil" @@ -33,14 +32,14 @@ import ( ) const ( - defaultTLSTTL = time.Hour * 2160 // About 3 month + defaultTLSTTL = Duration("2610h") // About 3 month ) // TLSSpec holds TLS specific configuration settings type TLSSpec struct { - CASecretName *string `json:"caSecretName,omitempty"` - AltNames []string `json:"altNames,omitempty"` - TTL *time.Duration `json:"ttl,omitempty"` + CASecretName *string `json:"caSecretName,omitempty"` + AltNames []string `json:"altNames,omitempty"` + TTL *Duration `json:"ttl,omitempty"` } const ( @@ -59,8 +58,8 @@ func (s TLSSpec) GetAltNames() []string { } // GetTTL returns the value of ttl. -func (s TLSSpec) GetTTL() time.Duration { - return util.DurationOrDefault(s.TTL) +func (s TLSSpec) GetTTL() Duration { + return DurationOrDefault(s.TTL) } // IsSecure returns true when a CA secret has been set, false otherwise. @@ -94,6 +93,9 @@ func (s TLSSpec) Validate() error { if _, _, _, err := s.GetParsedAltNames(); err != nil { return maskAny(err) } + if err := s.GetTTL().Validate(); err != nil { + return maskAny(err) + } } return nil } @@ -105,10 +107,10 @@ func (s *TLSSpec) SetDefaults(defaultCASecretName string) { // string should result in the default value. s.CASecretName = util.NewString(defaultCASecretName) } - if s.GetTTL() == 0 { + if s.GetTTL() == "" { // Note that we don't check for nil here, since even a specified, but zero // should result in the default value. - s.TTL = util.NewDuration(defaultTLSTTL) + s.TTL = NewDuration(defaultTLSTTL) } } @@ -121,6 +123,6 @@ func (s *TLSSpec) SetDefaultsFrom(source TLSSpec) { s.AltNames = source.AltNames } if s.TTL == nil { - s.TTL = util.NewDurationOrNil(source.TTL) + s.TTL = NewDurationOrNil(source.TTL) } } diff --git a/pkg/apis/deployment/v1alpha/tls_spec_test.go b/pkg/apis/deployment/v1alpha/tls_spec_test.go index 15d006ab8..977a61dd8 100644 --- a/pkg/apis/deployment/v1alpha/tls_spec_test.go +++ b/pkg/apis/deployment/v1alpha/tls_spec_test.go @@ -62,5 +62,5 @@ func TestTLSSpecSetDefaults(t *testing.T) { assert.Len(t, def(TLSSpec{}).GetAltNames(), 0) assert.Len(t, def(TLSSpec{AltNames: []string{"foo.local"}}).GetAltNames(), 1) assert.Equal(t, defaultTLSTTL, def(TLSSpec{}).GetTTL()) - assert.Equal(t, time.Hour, def(TLSSpec{TTL: util.NewDuration(time.Hour)}).GetTTL()) + assert.Equal(t, time.Hour, def(TLSSpec{TTL: NewDuration("1h")}).GetTTL().AsDuration()) } diff --git a/pkg/apis/deployment/v1alpha/zz_generated.deepcopy.go b/pkg/apis/deployment/v1alpha/zz_generated.deepcopy.go index 44dd343cd..030f9f26c 100644 --- a/pkg/apis/deployment/v1alpha/zz_generated.deepcopy.go +++ b/pkg/apis/deployment/v1alpha/zz_generated.deepcopy.go @@ -627,7 +627,7 @@ func (in *TLSSpec) DeepCopyInto(out *TLSSpec) { if *in == nil { *out = nil } else { - *out = new(time.Duration) + *out = new(Duration) **out = **in } } diff --git a/pkg/deployment/resources/tls.go b/pkg/deployment/resources/tls.go index b3d1d4da9..725fa0c40 100644 --- a/pkg/deployment/resources/tls.go +++ b/pkg/deployment/resources/tls.go @@ -105,7 +105,7 @@ func createServerCertificate(log zerolog.Logger, cli v1.CoreV1Interface, serverN Hosts: append(append(serverNames, dnsNames...), ipAddresses...), EmailAddresses: emailAddress, ValidFrom: time.Now(), - ValidFor: spec.GetTTL(), + ValidFor: spec.GetTTL().AsDuration(), IsCA: false, ECDSACurve: tlsECDSACurve, } diff --git a/tests/scale_test.go b/tests/scale_test.go index 70f78e834..9c77b0a47 100644 --- a/tests/scale_test.go +++ b/tests/scale_test.go @@ -3,7 +3,6 @@ package tests import ( "context" "testing" - "time" "github.com/dchest/uniuri" @@ -24,7 +23,7 @@ func TestScaleClusterNonTLS(t *testing.T) { // Prepare deployment config depl := newDeployment("test-scale-non-tls" + uniuri.NewLen(4)) depl.Spec.Mode = api.NewMode(api.DeploymentModeCluster) - depl.Spec.TLS = api.TLSSpec{util.NewString("None"), nil, util.NewDuration(time.Second * 50)} + depl.Spec.TLS = api.TLSSpec{CASecretName: util.NewString("None")} depl.Spec.SetDefaults(depl.GetName()) // this must be last // Create deployment