diff --git a/binary_transparency/firmware/internal/ftmap/aggregate.go b/binary_transparency/firmware/internal/ftmap/aggregate.go index 92cf99d05..571cf6b43 100644 --- a/binary_transparency/firmware/internal/ftmap/aggregate.go +++ b/binary_transparency/firmware/internal/ftmap/aggregate.go @@ -30,6 +30,8 @@ import ( func init() { beam.RegisterFunction(aggregationFn) + beam.RegisterFunction(annotationLogIndexFn) + beam.RegisterFunction(logEntryIndexFn) beam.RegisterType(reflect.TypeOf((*api.AggregatedFirmware)(nil)).Elem()) beam.RegisterType(reflect.TypeOf((*aggregatedFirmwareHashFn)(nil)).Elem()) } @@ -40,14 +42,18 @@ func init() { // - AnnotationMalware: `Good` is true providing there are no malware annotations that claim the // firmware is bad. func Aggregate(s beam.Scope, treeID int64, fws, annotationMalwares beam.PCollection) (beam.PCollection, beam.PCollection) { - keyedFws := beam.ParDo(s, func(l *firmwareLogEntry) (uint64, *firmwareLogEntry) { return uint64(l.Index), l }, fws) - keyedAnns := beam.ParDo(s, func(a *annotationMalwareLogEntry) (uint64, *annotationMalwareLogEntry) { - return a.Annotation.FirmwareID.LogIndex, a - }, annotationMalwares) + keyedFws := beam.ParDo(s, logEntryIndexFn, fws) + keyedAnns := beam.ParDo(s, annotationLogIndexFn, annotationMalwares) annotations := beam.ParDo(s, aggregationFn, beam.CoGroupByKey(s, keyedFws, keyedAnns)) return beam.ParDo(s, &aggregatedFirmwareHashFn{treeID}, annotations), annotations } +func logEntryIndexFn(l *firmwareLogEntry) (uint64, *firmwareLogEntry) { return uint64(l.Index), l } + +func annotationLogIndexFn(a *annotationMalwareLogEntry) (uint64, *annotationMalwareLogEntry) { + return a.Annotation.FirmwareID.LogIndex, a +} + func aggregationFn(fwIndex uint64, fwit func(**firmwareLogEntry) bool, amit func(**annotationMalwareLogEntry) bool) (*api.AggregatedFirmware, error) { // There will be exactly one firmware entry for the log index. var fwle *firmwareLogEntry diff --git a/binary_transparency/firmware/internal/ftmap/aggregate_test.go b/binary_transparency/firmware/internal/ftmap/aggregate_test.go index 22ed79c7e..01f6e2612 100644 --- a/binary_transparency/firmware/internal/ftmap/aggregate_test.go +++ b/binary_transparency/firmware/internal/ftmap/aggregate_test.go @@ -19,11 +19,20 @@ import ( "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" "github.com/google/trillian-examples/binary_transparency/firmware/api" ) +func init() { + register.Function1x1(testAggregationToStringFn) +} + +func TestMain(m *testing.M) { + ptest.Main(m) +} + func TestAggregate(t *testing.T) { fwEntries := []*firmwareLogEntry{ {Index: 0, Firmware: createFW("dummy", 400)}, @@ -104,8 +113,7 @@ func TestAggregate(t *testing.T) { passert.Count(s, entries, "entries", len(fwEntries)) passert.Count(s, aggs, "aggs", len(fwEntries)) - aggregationToString := func(a *api.AggregatedFirmware) string { return fmt.Sprintf("%d: %t", a.Index, a.Good) } - passert.Equals(s, beam.ParDo(s, aggregationToString, aggs), beam.CreateList(s, test.wantGood)) + passert.Equals(s, beam.ParDo(s, testAggregationToStringFn, aggs), beam.CreateList(s, test.wantGood)) err := ptest.Run(p) if err != nil { @@ -114,3 +122,7 @@ func TestAggregate(t *testing.T) { }) } } + +func testAggregationToStringFn(a *api.AggregatedFirmware) string { + return fmt.Sprintf("%d: %t", a.Index, a.Good) +} diff --git a/binary_transparency/firmware/internal/ftmap/log.go b/binary_transparency/firmware/internal/ftmap/log.go index 4bfb1dabd..053c1e1a6 100644 --- a/binary_transparency/firmware/internal/ftmap/log.go +++ b/binary_transparency/firmware/internal/ftmap/log.go @@ -32,6 +32,7 @@ import ( ) func init() { + beam.RegisterFunction(logEntryDeviceIDFn) beam.RegisterFunction(makeDeviceReleaseLogFn) beam.RegisterType(reflect.TypeOf((*moduleLogHashFn)(nil)).Elem()) beam.RegisterType(reflect.TypeOf((*api.DeviceReleaseLog)(nil)).Elem()) @@ -44,11 +45,15 @@ func init() { // 1. the first is of type Entry; the key/value data to include in the map // 2. the second is of type DeviceReleaseLog. func MakeReleaseLogs(s beam.Scope, treeID int64, logEntries beam.PCollection) (beam.PCollection, beam.PCollection) { - keyed := beam.ParDo(s, func(l *firmwareLogEntry) (string, *firmwareLogEntry) { return l.Firmware.DeviceID, l }, logEntries) + keyed := beam.ParDo(s, logEntryDeviceIDFn, logEntries) logs := beam.ParDo(s, makeDeviceReleaseLogFn, beam.GroupByKey(s, keyed)) return beam.ParDo(s, &moduleLogHashFn{TreeID: treeID}, logs), logs } +func logEntryDeviceIDFn(l *firmwareLogEntry) (string, *firmwareLogEntry) { + return l.Firmware.DeviceID, l +} + type moduleLogHashFn struct { TreeID int64 diff --git a/binary_transparency/firmware/internal/ftmap/pipeline_test.go b/binary_transparency/firmware/internal/ftmap/pipeline_test.go index 2ea1bf588..d4c8a5817 100644 --- a/binary_transparency/firmware/internal/ftmap/pipeline_test.go +++ b/binary_transparency/firmware/internal/ftmap/pipeline_test.go @@ -21,12 +21,18 @@ import ( "testing" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/register" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" "github.com/google/trillian-examples/binary_transparency/firmware/api" "github.com/google/trillian/experimental/batchmap" ) +func init() { + register.Function1x1(testLogToStringFn) + register.Function1x1(testRootToStringFn) +} + func TestCreate(t *testing.T) { tests := []struct { name string @@ -74,11 +80,9 @@ func TestCreate(t *testing.T) { t.Fatalf("failed to Create(): %v", err) } - rootToString := func(t *batchmap.Tile) string { return fmt.Sprintf("%x", t.RootHash) } - passert.Equals(s, beam.ParDo(s, rootToString, result.MapTiles), test.wantRoot) + passert.Equals(s, beam.ParDo(s, testRootToStringFn, result.MapTiles), test.wantRoot) - logToString := func(l *api.DeviceReleaseLog) string { return fmt.Sprintf("%s: %v", l.DeviceID, l.Revisions) } - passert.Equals(s, beam.ParDo(s, logToString, result.DeviceLogs), beam.CreateList(s, test.wantLogs)) + passert.Equals(s, beam.ParDo(s, testLogToStringFn, result.DeviceLogs), beam.CreateList(s, test.wantLogs)) err = ptest.Run(p) if err != nil { @@ -87,6 +91,10 @@ func TestCreate(t *testing.T) { }) } } +func testRootToStringFn(t *batchmap.Tile) string { return fmt.Sprintf("%x", t.RootHash) } +func testLogToStringFn(l *api.DeviceReleaseLog) string { + return fmt.Sprintf("%s: %v", l.DeviceID, l.Revisions) +} func createFW(device string, revision uint64) api.FirmwareMetadata { image := fmt.Sprintf("this image is the firmware at revision %d for device %s.", revision, device)