diff --git a/cmd/rekor-server/app/flags.go b/cmd/rekor-server/app/flags.go index 76705f285..f20412ffc 100644 --- a/cmd/rekor-server/app/flags.go +++ b/cmd/rekor-server/app/flags.go @@ -19,23 +19,20 @@ import ( "fmt" "strconv" "strings" -) -type LogRange struct { - TreeID uint64 - TreeLength uint64 -} + "github.com/sigstore/rekor/pkg/api" +) -type LogRanges struct { - Ranges []LogRange +type LogRangesFlag struct { + Ranges api.LogRanges } -func (l *LogRanges) Set(s string) error { +func (l *LogRangesFlag) Set(s string) error { ranges := strings.Split(s, ",") - l.Ranges = []LogRange{} + l.Ranges = api.LogRanges{} var err error - inputRanges := []LogRange{} + inputRanges := []api.LogRange{} // Only go up to the second to last one, the last one is special cased beow for _, r := range ranges[:len(ranges)-1] { @@ -43,7 +40,7 @@ func (l *LogRanges) Set(s string) error { if len(split) != 2 { return fmt.Errorf("invalid range flag, expected two parts separated by an =, got %s", r) } - lr := LogRange{} + lr := api.LogRange{} lr.TreeID, err = strconv.ParseUint(split[0], 10, 64) if err != nil { return err @@ -63,7 +60,7 @@ func (l *LogRanges) Set(s string) error { return err } - inputRanges = append(inputRanges, LogRange{ + inputRanges = append(inputRanges, api.LogRange{ TreeID: lastTreeID, }) @@ -76,36 +73,20 @@ func (l *LogRanges) Set(s string) error { TreeIDs[lr.TreeID] = struct{}{} } - l.Ranges = inputRanges + l.Ranges = api.LogRanges{ + Ranges: inputRanges, + } return nil } -func (l *LogRanges) String() string { +func (l *LogRangesFlag) String() string { ranges := []string{} - for _, r := range l.Ranges { + for _, r := range l.Ranges.Ranges { ranges = append(ranges, fmt.Sprintf("%d=%d", r.TreeID, r.TreeLength)) } return strings.Join(ranges, ",") } -func (l *LogRanges) Type() string { - return "LogRanges" -} - -func (l *LogRanges) ResolveVirtualIndex(index int) (uint64, uint64) { - indexLeft := index - for _, l := range l.Ranges { - if indexLeft < int(l.TreeLength) { - return l.TreeID, uint64(indexLeft) - } - indexLeft -= int(l.TreeLength) - } - - // Return the last one! - return l.Ranges[len(l.Ranges)-1].TreeID, uint64(indexLeft) -} - -// ActiveIndex returns the active shard index, always the last shard in the range -func (l *LogRanges) ActiveIndex() uint64 { - return l.Ranges[len(l.Ranges)-1].TreeID +func (l *LogRangesFlag) Type() string { + return "LogRangesFlag" } diff --git a/cmd/rekor-server/app/flags_test.go b/cmd/rekor-server/app/flags_test.go index ab82b261e..90fa8b4ce 100644 --- a/cmd/rekor-server/app/flags_test.go +++ b/cmd/rekor-server/app/flags_test.go @@ -19,19 +19,20 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/sigstore/rekor/pkg/api" ) func TestLogRanges_Set(t *testing.T) { tests := []struct { name string arg string - want []LogRange + want []api.LogRange active uint64 }{ { name: "one, no length", arg: "1234", - want: []LogRange{ + want: []api.LogRange{ { TreeID: 1234, TreeLength: 0, @@ -42,7 +43,7 @@ func TestLogRanges_Set(t *testing.T) { { name: "two", arg: "1234=10,7234", - want: []LogRange{ + want: []api.LogRange{ { TreeID: 1234, TreeLength: 10, @@ -57,16 +58,16 @@ func TestLogRanges_Set(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := &LogRanges{} + l := &LogRangesFlag{} if err := l.Set(tt.arg); err != nil { t.Errorf("LogRanges.Set() expected no error, got %v", err) } - if diff := cmp.Diff(tt.want, l.Ranges); diff != "" { + if diff := cmp.Diff(tt.want, l.Ranges.Ranges); diff != "" { t.Errorf(diff) } - active := l.ActiveIndex() + active := l.Ranges.ActiveIndex() if active != tt.active { t.Errorf("LogRanges.Active() expected %d no error, got %d", tt.active, active) } @@ -94,50 +95,10 @@ func TestLogRanges_SetErr(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - l := &LogRanges{} + l := &LogRangesFlag{} if err := l.Set(tt.arg); err == nil { t.Error("LogRanges.Set() expected error but got none") } }) } } - -func TestLogRanges_ResolveVirtualIndex(t *testing.T) { - lrs := LogRanges{ - Ranges: []LogRange{ - {TreeID: 1, TreeLength: 17}, - {TreeID: 2, TreeLength: 1}, - {TreeID: 3, TreeLength: 100}, - {TreeID: 4}, - }, - } - - for _, tt := range []struct { - Index int - WantTreeID uint64 - WantIndex uint64 - }{ - { - Index: 3, - WantTreeID: 1, WantIndex: 3, - }, - // This is the first (0th) entry in the next tree - { - Index: 17, - WantTreeID: 2, WantIndex: 0, - }, - // Overflow - { - Index: 3000, - WantTreeID: 4, WantIndex: 2882, - }, - } { - tree, index := lrs.ResolveVirtualIndex(tt.Index) - if tree != tt.WantTreeID { - t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID) - } - if index != tt.WantIndex { - t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex) - } - } -} diff --git a/cmd/rekor-server/app/root.go b/cmd/rekor-server/app/root.go index 92aa60bdc..1ac675705 100644 --- a/cmd/rekor-server/app/root.go +++ b/cmd/rekor-server/app/root.go @@ -33,7 +33,7 @@ var ( cfgFile string logType string enablePprof bool - logRangeMap LogRanges + logRangeMap LogRangesFlag ) // rootCmd represents the base command when called without any subcommands diff --git a/cmd/rekor-server/app/serve.go b/cmd/rekor-server/app/serve.go index ae13a6367..768617fe6 100644 --- a/cmd/rekor-server/app/serve.go +++ b/cmd/rekor-server/app/serve.go @@ -102,7 +102,7 @@ var serveCmd = &cobra.Command{ server.Port = int(viper.GetUint("port")) server.EnabledListeners = []string{"http"} - api.ConfigureAPI() + api.ConfigureAPI(logRangeMap.Ranges) server.ConfigureAPI() http.Handle("/metrics", promhttp.Handler()) diff --git a/pkg/api/api.go b/pkg/api/api.go index 10a548581..5211adede 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -56,6 +56,7 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) { type API struct { logClient trillian.TrillianLogClient logID int64 + logRanges *LogRanges pubkey string // PEM encoded public key pubkeyHash string // SHA256 hash of DER-encoded public key signer signature.Signer @@ -64,7 +65,7 @@ type API struct { certChainPem string // PEM encoded timestamping cert chain } -func NewAPI() (*API, error) { +func NewAPI(ranges LogRanges) (*API, error) { logRPCServer := fmt.Sprintf("%s:%d", viper.GetString("trillian_log_server.address"), viper.GetUint("trillian_log_server.port")) @@ -137,6 +138,7 @@ func NewAPI() (*API, error) { // Transparency Log Stuff logClient: logClient, logID: tLogID, + logRanges: &ranges, // Signing/verifying fields pubkey: string(pubkey), pubkeyHash: hex.EncodeToString(pubkeyHashBytes[:]), @@ -154,10 +156,11 @@ var ( storageClient storage.AttestationStorage ) -func ConfigureAPI() { +func ConfigureAPI(ranges LogRanges) { cfg := radix.PoolConfig{} var err error - api, err = NewAPI() + + api, err = NewAPI(ranges) if err != nil { log.Logger.Panic(err) } diff --git a/pkg/api/ranges.go b/pkg/api/ranges.go new file mode 100644 index 000000000..9b30e8469 --- /dev/null +++ b/pkg/api/ranges.go @@ -0,0 +1,43 @@ +// +// Copyright 2021 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +type LogRanges struct { + Ranges []LogRange +} + +type LogRange struct { + TreeID uint64 + TreeLength uint64 +} + +func (l *LogRanges) ResolveVirtualIndex(index int) (uint64, uint64) { + indexLeft := index + for _, l := range l.Ranges { + if indexLeft < int(l.TreeLength) { + return l.TreeID, uint64(indexLeft) + } + indexLeft -= int(l.TreeLength) + } + + // Return the last one! + return l.Ranges[len(l.Ranges)-1].TreeID, uint64(indexLeft) +} + +// ActiveIndex returns the active shard index, always the last shard in the range +func (l *LogRanges) ActiveIndex() uint64 { + return l.Ranges[len(l.Ranges)-1].TreeID +} diff --git a/pkg/api/ranges_test.go b/pkg/api/ranges_test.go new file mode 100644 index 000000000..aad6a662f --- /dev/null +++ b/pkg/api/ranges_test.go @@ -0,0 +1,58 @@ +// +// Copyright 2021 The Sigstore Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import "testing" + +func TestLogRanges_ResolveVirtualIndex(t *testing.T) { + lrs := LogRanges{ + Ranges: []LogRange{ + {TreeID: 1, TreeLength: 17}, + {TreeID: 2, TreeLength: 1}, + {TreeID: 3, TreeLength: 100}, + {TreeID: 4}, + }, + } + + for _, tt := range []struct { + Index int + WantTreeID uint64 + WantIndex uint64 + }{ + { + Index: 3, + WantTreeID: 1, WantIndex: 3, + }, + // This is the first (0th) entry in the next tree + { + Index: 17, + WantTreeID: 2, WantIndex: 0, + }, + // Overflow + { + Index: 3000, + WantTreeID: 4, WantIndex: 2882, + }, + } { + tree, index := lrs.ResolveVirtualIndex(tt.Index) + if tree != tt.WantTreeID { + t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID) + } + if index != tt.WantIndex { + t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex) + } + } +}