diff --git a/pkg/flags/uint32.go b/pkg/flags/uint32.go new file mode 100644 index 00000000000..496730a4549 --- /dev/null +++ b/pkg/flags/uint32.go @@ -0,0 +1,45 @@ +// Copyright 2022 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flags + +import ( + "flag" + "strconv" +) + +type uint32Value uint32 + +// NewUint32Value creates an uint32 instance with the provided value. +func NewUint32Value(v uint32) *uint32Value { + val := new(uint32Value) + *val = uint32Value(v) + return val +} + +// Set parses a command line uint32 value. +// Implements "flag.Value" interface. +func (i *uint32Value) Set(s string) error { + v, err := strconv.ParseUint(s, 0, 32) + *i = uint32Value(v) + return err +} + +func (i *uint32Value) String() string { return strconv.FormatUint(uint64(*i), 10) } + +// Uint32FromFlag return the uint32 value of a flag with the given name +func Uint32FromFlag(fs *flag.FlagSet, name string) uint32 { + val := *fs.Lookup(name).Value.(*uint32Value) + return uint32(val) +} diff --git a/pkg/flags/uint32_test.go b/pkg/flags/uint32_test.go new file mode 100644 index 00000000000..aa7487a2320 --- /dev/null +++ b/pkg/flags/uint32_test.go @@ -0,0 +1,111 @@ +// Copyright 2022 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flags + +import ( + "flag" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUint32Value(t *testing.T) { + cases := []struct { + name string + s string + expectedVal uint32 + expectError bool + }{ + { + name: "normal uint32 value", + s: "200", + expectedVal: 200, + }, + { + name: "zero value", + s: "0", + expectedVal: 0, + }, + { + name: "negative int value", + s: "-200", + expectError: true, + }, + { + name: "invalid integer value", + s: "invalid", + expectError: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var val uint32Value + err := val.Set(tc.s) + + if tc.expectError { + if err == nil { + t.Errorf("Expected failure on parsing uint32 value from %s", tc.s) + } + } else { + if err != nil { + t.Errorf("Unexpected error when parsing %s: %v", tc.s, err) + } + assert.Equal(t, uint32(val), tc.expectedVal) + } + }) + } +} + +func TestUint32FromFlag(t *testing.T) { + const flagName = "max-concurrent-streams" + + cases := []struct { + name string + defaultVal uint32 + arguments []string + expectedVal uint32 + }{ + { + name: "only default value", + defaultVal: 15, + arguments: []string{}, + expectedVal: 15, + }, + { + name: "argument has different value from the default one", + defaultVal: 16, + arguments: []string{"--max-concurrent-streams", "200"}, + expectedVal: 200, + }, + { + name: "argument has the same value from the default one", + defaultVal: 105, + arguments: []string{"--max-concurrent-streams", "105"}, + expectedVal: 105, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + fs := flag.NewFlagSet("etcd", flag.ContinueOnError) + fs.Var(NewUint32Value(tc.defaultVal), flagName, "Maximum concurrent streams that each client can open at a time.") + if err := fs.Parse(tc.arguments); err != nil { + t.Fatalf("Unexpected error: %v\n", err) + } + actualMaxStream := Uint32FromFlag(fs, flagName) + assert.Equal(t, actualMaxStream, tc.expectedVal) + }) + } +}