Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the shard map parsing so we can pass it down into the API ob… #564

Merged
merged 1 commit into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 16 additions & 35 deletions cmd/rekor-server/app/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,28 @@ import (
"fmt"
"strconv"
"strings"
)

type LogRange struct {
TreeID uint64
TreeLength uint64
}
"github.com/sigstore/rekor/pkg/api"
)

type LogRanges struct {
Ranges []LogRange
type LogRangesFlag struct {
Ranges api.LogRanges
}

func (l *LogRanges) Set(s string) error {
func (l *LogRangesFlag) Set(s string) error {
ranges := strings.Split(s, ",")
l.Ranges = []LogRange{}
l.Ranges = api.LogRanges{}

var err error
inputRanges := []LogRange{}
inputRanges := []api.LogRange{}

// Only go up to the second to last one, the last one is special cased beow
for _, r := range ranges[:len(ranges)-1] {
split := strings.SplitN(r, "=", 2)
if len(split) != 2 {
return fmt.Errorf("invalid range flag, expected two parts separated by an =, got %s", r)
}
lr := LogRange{}
lr := api.LogRange{}
lr.TreeID, err = strconv.ParseUint(split[0], 10, 64)
if err != nil {
return err
Expand All @@ -63,7 +60,7 @@ func (l *LogRanges) Set(s string) error {
return err
}

inputRanges = append(inputRanges, LogRange{
inputRanges = append(inputRanges, api.LogRange{
TreeID: lastTreeID,
})

Expand All @@ -76,36 +73,20 @@ func (l *LogRanges) Set(s string) error {
TreeIDs[lr.TreeID] = struct{}{}
}

l.Ranges = inputRanges
l.Ranges = api.LogRanges{
Ranges: inputRanges,
}
return nil
}

func (l *LogRanges) String() string {
func (l *LogRangesFlag) String() string {
ranges := []string{}
for _, r := range l.Ranges {
for _, r := range l.Ranges.Ranges {
ranges = append(ranges, fmt.Sprintf("%d=%d", r.TreeID, r.TreeLength))
}
return strings.Join(ranges, ",")
}

func (l *LogRanges) Type() string {
return "LogRanges"
}

func (l *LogRanges) ResolveVirtualIndex(index int) (uint64, uint64) {
indexLeft := index
for _, l := range l.Ranges {
if indexLeft < int(l.TreeLength) {
return l.TreeID, uint64(indexLeft)
}
indexLeft -= int(l.TreeLength)
}

// Return the last one!
return l.Ranges[len(l.Ranges)-1].TreeID, uint64(indexLeft)
}

// ActiveIndex returns the active shard index, always the last shard in the range
func (l *LogRanges) ActiveIndex() uint64 {
return l.Ranges[len(l.Ranges)-1].TreeID
func (l *LogRangesFlag) Type() string {
return "LogRangesFlag"
}
55 changes: 8 additions & 47 deletions cmd/rekor-server/app/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/sigstore/rekor/pkg/api"
)

func TestLogRanges_Set(t *testing.T) {
tests := []struct {
name string
arg string
want []LogRange
want []api.LogRange
active uint64
}{
{
name: "one, no length",
arg: "1234",
want: []LogRange{
want: []api.LogRange{
{
TreeID: 1234,
TreeLength: 0,
Expand All @@ -42,7 +43,7 @@ func TestLogRanges_Set(t *testing.T) {
{
name: "two",
arg: "1234=10,7234",
want: []LogRange{
want: []api.LogRange{
{
TreeID: 1234,
TreeLength: 10,
Expand All @@ -57,16 +58,16 @@ func TestLogRanges_Set(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &LogRanges{}
l := &LogRangesFlag{}
if err := l.Set(tt.arg); err != nil {
t.Errorf("LogRanges.Set() expected no error, got %v", err)
}

if diff := cmp.Diff(tt.want, l.Ranges); diff != "" {
if diff := cmp.Diff(tt.want, l.Ranges.Ranges); diff != "" {
t.Errorf(diff)
}

active := l.ActiveIndex()
active := l.Ranges.ActiveIndex()
if active != tt.active {
t.Errorf("LogRanges.Active() expected %d no error, got %d", tt.active, active)
}
Expand Down Expand Up @@ -94,50 +95,10 @@ func TestLogRanges_SetErr(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := &LogRanges{}
l := &LogRangesFlag{}
if err := l.Set(tt.arg); err == nil {
t.Error("LogRanges.Set() expected error but got none")
}
})
}
}

func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
lrs := LogRanges{
Ranges: []LogRange{
{TreeID: 1, TreeLength: 17},
{TreeID: 2, TreeLength: 1},
{TreeID: 3, TreeLength: 100},
{TreeID: 4},
},
}

for _, tt := range []struct {
Index int
WantTreeID uint64
WantIndex uint64
}{
{
Index: 3,
WantTreeID: 1, WantIndex: 3,
},
// This is the first (0th) entry in the next tree
{
Index: 17,
WantTreeID: 2, WantIndex: 0,
},
// Overflow
{
Index: 3000,
WantTreeID: 4, WantIndex: 2882,
},
} {
tree, index := lrs.ResolveVirtualIndex(tt.Index)
if tree != tt.WantTreeID {
t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID)
}
if index != tt.WantIndex {
t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex)
}
}
}
2 changes: 1 addition & 1 deletion cmd/rekor-server/app/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ var (
cfgFile string
logType string
enablePprof bool
logRangeMap LogRanges
logRangeMap LogRangesFlag
)

// rootCmd represents the base command when called without any subcommands
Expand Down
2 changes: 1 addition & 1 deletion cmd/rekor-server/app/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ var serveCmd = &cobra.Command{
server.Port = int(viper.GetUint("port"))
server.EnabledListeners = []string{"http"}

api.ConfigureAPI()
api.ConfigureAPI(logRangeMap.Ranges)
server.ConfigureAPI()

http.Handle("/metrics", promhttp.Handler())
Expand Down
9 changes: 6 additions & 3 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func dial(ctx context.Context, rpcServer string) (*grpc.ClientConn, error) {
type API struct {
logClient trillian.TrillianLogClient
logID int64
logRanges *LogRanges
pubkey string // PEM encoded public key
pubkeyHash string // SHA256 hash of DER-encoded public key
signer signature.Signer
Expand All @@ -64,7 +65,7 @@ type API struct {
certChainPem string // PEM encoded timestamping cert chain
}

func NewAPI() (*API, error) {
func NewAPI(ranges LogRanges) (*API, error) {
logRPCServer := fmt.Sprintf("%s:%d",
viper.GetString("trillian_log_server.address"),
viper.GetUint("trillian_log_server.port"))
Expand Down Expand Up @@ -137,6 +138,7 @@ func NewAPI() (*API, error) {
// Transparency Log Stuff
logClient: logClient,
logID: tLogID,
logRanges: &ranges,
// Signing/verifying fields
pubkey: string(pubkey),
pubkeyHash: hex.EncodeToString(pubkeyHashBytes[:]),
Expand All @@ -154,10 +156,11 @@ var (
storageClient storage.AttestationStorage
)

func ConfigureAPI() {
func ConfigureAPI(ranges LogRanges) {
cfg := radix.PoolConfig{}
var err error
api, err = NewAPI()

api, err = NewAPI(ranges)
if err != nil {
log.Logger.Panic(err)
}
Expand Down
43 changes: 43 additions & 0 deletions pkg/api/ranges.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//
// Copyright 2021 The Sigstore Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

type LogRanges struct {
Ranges []LogRange
}

type LogRange struct {
TreeID uint64
TreeLength uint64
}

func (l *LogRanges) ResolveVirtualIndex(index int) (uint64, uint64) {
indexLeft := index
for _, l := range l.Ranges {
if indexLeft < int(l.TreeLength) {
return l.TreeID, uint64(indexLeft)
}
indexLeft -= int(l.TreeLength)
}

// Return the last one!
return l.Ranges[len(l.Ranges)-1].TreeID, uint64(indexLeft)
}

// ActiveIndex returns the active shard index, always the last shard in the range
func (l *LogRanges) ActiveIndex() uint64 {
return l.Ranges[len(l.Ranges)-1].TreeID
}
58 changes: 58 additions & 0 deletions pkg/api/ranges_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//
// Copyright 2021 The Sigstore Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

import "testing"

func TestLogRanges_ResolveVirtualIndex(t *testing.T) {
lrs := LogRanges{
Ranges: []LogRange{
{TreeID: 1, TreeLength: 17},
{TreeID: 2, TreeLength: 1},
{TreeID: 3, TreeLength: 100},
{TreeID: 4},
},
}

for _, tt := range []struct {
Index int
WantTreeID uint64
WantIndex uint64
}{
{
Index: 3,
WantTreeID: 1, WantIndex: 3,
},
// This is the first (0th) entry in the next tree
{
Index: 17,
WantTreeID: 2, WantIndex: 0,
},
// Overflow
{
Index: 3000,
WantTreeID: 4, WantIndex: 2882,
},
} {
tree, index := lrs.ResolveVirtualIndex(tt.Index)
if tree != tt.WantTreeID {
t.Errorf("LogRanges.ResolveVirtualIndex() tree = %v, want %v", tree, tt.WantTreeID)
}
if index != tt.WantIndex {
t.Errorf("LogRanges.ResolveVirtualIndex() index = %v, want %v", index, tt.WantIndex)
}
}
}