diff --git a/handlers_test.go b/handlers_test.go index 1d5d4bca2..fb8646655 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/gorilla/mux" "github.com/mozilla-services/autograph/database" @@ -303,6 +304,23 @@ func TestHeartbeat(t *testing.T) { } } +func TestRequestIDWellFormed(t *testing.T) { + // This method of testing middleware is cribbed from + // https://stackoverflow.com/questions/51201056/testing-golang-middleware-that-modifies-the-request + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + val := r.Context().Value(contextKeyRequestID).(string) + if uuid.Validate(val) != nil { + t.Errorf("requestID is not a valid uuid! %v", val) + } + }) + + handlerToTest := setRequestID()(nextHandler) + + req := httptest.NewRequest("GET", "http://foo.bar/__heartbeat__", nil) + + handlerToTest.ServeHTTP(httptest.NewRecorder(), req) +} + func TestHeartbeatChecksHSMStatusFails(t *testing.T) { // NB: do not run in parallel with TestHeartbeat* ag.heartbeatConf = &heartbeatConfig{ diff --git a/middleware.go b/middleware.go index 00217fd4e..5d8e8096c 100644 --- a/middleware.go +++ b/middleware.go @@ -1,9 +1,10 @@ package main import ( - "math/rand" "net/http" "time" + + "github.com/google/uuid" ) // Middleware wraps an http.Handler with additional functionality @@ -27,13 +28,18 @@ func setResponseHeaders() Middleware { func setRequestID() Middleware { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - rid := make([]rune, 16) - letters := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") - for i := range rid { - rid[i] = letters[rand.Intn(len(letters))] + // NewV7 is used instead of New because the latter will panic + // if can't generate a UUID. It's preferably for us to have + // worse request ids than panic. + uuid, err := uuid.NewV7() + var rid string + if err != nil { + rid = "-" + } else { + rid = uuid.String() } - h.ServeHTTP(w, addToContext(r, contextKeyRequestID, string(rid))) + h.ServeHTTP(w, addToContext(r, contextKeyRequestID, rid)) }) } }