diff --git a/container.go b/container.go index c0ae935f07..90a7ba91cb 100644 --- a/container.go +++ b/container.go @@ -74,7 +74,7 @@ type Container interface { type ImageBuildInfo interface { BuildOptions() (types.ImageBuildOptions, error) // converts the ImageBuildInfo to a types.ImageBuildOptions GetContext() (io.Reader, error) // the path to the build context - GetDockerfile() string // the relative path to the Dockerfile, including the fileitself + GetDockerfile() string // the relative path to the Dockerfile, including the file itself GetRepo() string // get repo label for image GetTag() string // get tag label for image ShouldPrintBuildLog() bool // allow build log to be printed to stdout @@ -286,34 +286,34 @@ func (c *ContainerRequest) GetBuildArgs() map[string]*string { return c.FromDockerfile.BuildArgs } -// GetDockerfile returns the Dockerfile from the ContainerRequest, defaults to "Dockerfile" +// GetDockerfile returns the Dockerfile from the ContainerRequest, defaults to "Dockerfile". +// Sets FromDockerfile.Dockerfile to the default if blank. func (c *ContainerRequest) GetDockerfile() string { - f := c.FromDockerfile.Dockerfile - if f == "" { - return "Dockerfile" + if c.FromDockerfile.Dockerfile == "" { + c.FromDockerfile.Dockerfile = "Dockerfile" } - return f + return c.FromDockerfile.Dockerfile } -// GetRepo returns the Repo label for image from the ContainerRequest, defaults to UUID +// GetRepo returns the Repo label for image from the ContainerRequest, defaults to UUID. +// Sets FromDockerfile.Repo to the default value if blank. func (c *ContainerRequest) GetRepo() string { - r := c.FromDockerfile.Repo - if r == "" { - return uuid.NewString() + if c.FromDockerfile.Repo == "" { + c.FromDockerfile.Repo = uuid.NewString() } - return strings.ToLower(r) + return strings.ToLower(c.FromDockerfile.Repo) } -// GetTag returns the Tag label for image from the ContainerRequest, defaults to UUID +// GetTag returns the Tag label for image from the ContainerRequest, defaults to UUID. +// Sets FromDockerfile.Tag to the default value if blank. func (c *ContainerRequest) GetTag() string { - t := c.FromDockerfile.Tag - if t == "" { - return uuid.NewString() + if c.FromDockerfile.Tag == "" { + c.FromDockerfile.Tag = uuid.NewString() } - return strings.ToLower(t) + return strings.ToLower(c.FromDockerfile.Tag) } // Deprecated: Testcontainers will detect registry credentials automatically, and it will be removed in the next major release. diff --git a/docker.go b/docker.go index cb003c6630..68b9495589 100644 --- a/docker.go +++ b/docker.go @@ -1069,11 +1069,29 @@ func (p *DockerProvider) CreateContainer(ctx context.Context, req ContainerReque var platform *specs.Platform + defaultHooks := []ContainerLifecycleHooks{ + DefaultLoggingHook(p.Logger), + } + + origLifecycleHooks := req.LifecycleHooks + req.LifecycleHooks = []ContainerLifecycleHooks{ + combineContainerHooks(defaultHooks, req.LifecycleHooks), + } + if req.ShouldBuildImage() { + if err = req.buildingHook(ctx); err != nil { + return nil, err + } + imageName, err = p.BuildImage(ctx, &req) if err != nil { return nil, err } + + req.Image = imageName + if err = req.builtHook(ctx); err != nil { + return nil, err + } } else { for _, is := range req.ImageSubstitutors { modifiedTag, err := is.Substitute(imageName) @@ -1149,13 +1167,12 @@ func (p *DockerProvider) CreateContainer(ctx context.Context, req ContainerReque networkingConfig := &network.NetworkingConfig{} // default hooks include logger hook and pre-create hook - defaultHooks := []ContainerLifecycleHooks{ - DefaultLoggingHook(p.Logger), + defaultHooks = append(defaultHooks, defaultPreCreateHook(p, dockerInput, hostConfig, networkingConfig), defaultCopyFileToContainerHook(req.Files), defaultLogConsumersHook(req.LogConsumerCfg), defaultReadinessHook(), - } + ) // in the case the container needs to access a local port // we need to forward the local port to the container @@ -1171,7 +1188,10 @@ func (p *DockerProvider) CreateContainer(ctx context.Context, req ContainerReque defaultHooks = append(defaultHooks, sshdForwardPortsHook) } - req.LifecycleHooks = []ContainerLifecycleHooks{combineContainerHooks(defaultHooks, req.LifecycleHooks)} + // Combine with the original LifecycleHooks to avoid duplicate logging hooks. + req.LifecycleHooks = []ContainerLifecycleHooks{ + combineContainerHooks(defaultHooks, origLifecycleHooks), + } err = req.creatingHook(ctx) if err != nil { diff --git a/docs/features/creating_container.md b/docs/features/creating_container.md index ec33bdb014..6a87477d0d 100644 --- a/docs/features/creating_container.md +++ b/docs/features/creating_container.md @@ -106,6 +106,8 @@ _Testcontainers for Go_ allows you to define your own lifecycle hooks for better You'll be able to pass multiple lifecycle hooks at the `ContainerRequest` as an array of `testcontainers.ContainerLifecycleHooks`. The `testcontainers.ContainerLifecycleHooks` struct defines the following lifecycle hooks, each of them backed by an array of functions representing the hooks: +* `PreBuilds` - hooks that are executed before the container is built +* `PostBuilds` - hooks that are executed after the container is built * `PreCreates` - hooks that are executed before the container is created * `PostCreates` - hooks that are executed after the container is created * `PreStarts` - hooks that are executed before the container is started diff --git a/lifecycle.go b/lifecycle.go index 57833dafc1..63446f715d 100644 --- a/lifecycle.go +++ b/lifecycle.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "reflect" "strings" "time" @@ -39,6 +40,8 @@ type ContainerHook func(ctx context.Context, ctr Container) error // to modify the container lifecycle. All the container lifecycle hooks except the PreCreates hooks // will be passed to the container once it's created type ContainerLifecycleHooks struct { + PreBuilds []ContainerRequestHook + PostBuilds []ContainerRequestHook PreCreates []ContainerRequestHook PostCreates []ContainerHook PreStarts []ContainerHook @@ -57,6 +60,18 @@ var DefaultLoggingHook = func(logger Logging) ContainerLifecycleHooks { } return ContainerLifecycleHooks{ + PreBuilds: []ContainerRequestHook{ + func(ctx context.Context, req ContainerRequest) error { + logger.Printf("🐳 Building image %s:%s", req.GetRepo(), req.GetTag()) + return nil + }, + }, + PostBuilds: []ContainerRequestHook{ + func(ctx context.Context, req ContainerRequest) error { + logger.Printf("✅ Built image %s", req.Image) + return nil + }, + }, PreCreates: []ContainerRequestHook{ func(ctx context.Context, req ContainerRequest) error { logger.Printf("🐳 Creating container for image %s", req.Image) @@ -284,11 +299,34 @@ var defaultReadinessHook = func() ContainerLifecycleHooks { } } +// buildingHook is a hook that will be called before a container image is built. +func (req ContainerRequest) buildingHook(ctx context.Context) error { + return req.applyLifecycleHooks(func(lifecycleHooks ContainerLifecycleHooks) error { + return lifecycleHooks.Building(ctx)(req) + }) +} + +// builtHook is a hook that will be called after a container image is built. +func (req ContainerRequest) builtHook(ctx context.Context) error { + return req.applyLifecycleHooks(func(lifecycleHooks ContainerLifecycleHooks) error { + return lifecycleHooks.Built(ctx)(req) + }) +} + // creatingHook is a hook that will be called before a container is created. func (req ContainerRequest) creatingHook(ctx context.Context) error { - errs := make([]error, len(req.LifecycleHooks)) - for i, lifecycleHooks := range req.LifecycleHooks { - errs[i] = lifecycleHooks.Creating(ctx)(req) + return req.applyLifecycleHooks(func(lifecycleHooks ContainerLifecycleHooks) error { + return lifecycleHooks.Creating(ctx)(req) + }) +} + +// applyLifecycleHooks calls hook on all LifecycleHooks. +func (req ContainerRequest) applyLifecycleHooks(hook func(lifecycleHooks ContainerLifecycleHooks) error) error { + var errs []error + for _, lifecycleHooks := range req.LifecycleHooks { + if err := hook(lifecycleHooks); err != nil { + errs = append(errs, err) + } } return errors.Join(errs...) @@ -370,9 +408,11 @@ func (c *DockerContainer) terminatedHook(ctx context.Context) error { // applyLifecycleHooks applies all lifecycle hooks reporting the container logs on error if logError is true. func (c *DockerContainer) applyLifecycleHooks(ctx context.Context, logError bool, hooks func(lifecycleHooks ContainerLifecycleHooks) []ContainerHook) error { - errs := make([]error, len(c.lifecycleHooks)) - for i, lifecycleHooks := range c.lifecycleHooks { - errs[i] = containerHookFn(ctx, hooks(lifecycleHooks))(c) + var errs []error + for _, lifecycleHooks := range c.lifecycleHooks { + if err := containerHookFn(ctx, hooks(lifecycleHooks))(c); err != nil { + errs = append(errs, err) + } } if err := errors.Join(errs...); err != nil { @@ -394,10 +434,26 @@ func (c *DockerContainer) applyLifecycleHooks(ctx context.Context, logError bool return nil } +// Building is a hook that will be called before a container image is built. +func (c ContainerLifecycleHooks) Building(ctx context.Context) func(req ContainerRequest) error { + return containerRequestHook(ctx, c.PreBuilds) +} + +// Building is a hook that will be called before a container image is built. +func (c ContainerLifecycleHooks) Built(ctx context.Context) func(req ContainerRequest) error { + return containerRequestHook(ctx, c.PostBuilds) +} + // Creating is a hook that will be called before a container is created. func (c ContainerLifecycleHooks) Creating(ctx context.Context) func(req ContainerRequest) error { + return containerRequestHook(ctx, c.PreCreates) +} + +// containerRequestHook returns a function that will iterate over all +// the hooks and call them one by one until there is an error. +func containerRequestHook(ctx context.Context, hooks []ContainerRequestHook) func(req ContainerRequest) error { return func(req ContainerRequest) error { - for _, hook := range c.PreCreates { + for _, hook := range hooks { if err := hook(ctx, req); err != nil { return err } @@ -411,9 +467,11 @@ func (c ContainerLifecycleHooks) Creating(ctx context.Context) func(req Containe // container lifecycle hooks. The created function will iterate over all the hooks and call them one by one. func containerHookFn(ctx context.Context, containerHook []ContainerHook) func(container Container) error { return func(ctr Container) error { - errs := make([]error, len(containerHook)) - for i, hook := range containerHook { - errs[i] = hook(ctx, ctr) + var errs []error + for _, hook := range containerHook { + if err := hook(ctx, ctr); err != nil { + errs = append(errs, err) + } } return errors.Join(errs...) @@ -532,65 +590,50 @@ func (p *DockerProvider) preCreateContainerHook(ctx context.Context, req Contain return nil } -// combineContainerHooks it returns just one ContainerLifecycle hook, as the result of combining -// the default hooks with the user-defined hooks. The function will loop over all the default hooks, -// storing each of the hooks in a slice, and then it will loop over all the user-defined hooks, -// appending or prepending them to the slice of hooks. The order of hooks is the following: -// - for Pre-hooks, always run the default hooks first, then append the user-defined hooks -// - for Post-hooks, always run the user-defined hooks first, then the default hooks +// combineContainerHooks returns a ContainerLifecycle hook as the result +// of combining the default hooks with the user-defined hooks. +// +// The order of hooks is the following: +// - Pre-hooks run the default hooks first then the user-defined hooks +// - Post-hooks run the user-defined hooks first then the default hooks func combineContainerHooks(defaultHooks, userDefinedHooks []ContainerLifecycleHooks) ContainerLifecycleHooks { - preCreates := []ContainerRequestHook{} - postCreates := []ContainerHook{} - preStarts := []ContainerHook{} - postStarts := []ContainerHook{} - postReadies := []ContainerHook{} - preStops := []ContainerHook{} - postStops := []ContainerHook{} - preTerminates := []ContainerHook{} - postTerminates := []ContainerHook{} - + // We use reflection here to ensure that any new hooks are handled. + var hooks ContainerLifecycleHooks + hooksVal := reflect.ValueOf(&hooks).Elem() + hooksType := reflect.TypeOf(hooks) for _, defaultHook := range defaultHooks { - preCreates = append(preCreates, defaultHook.PreCreates...) - preStarts = append(preStarts, defaultHook.PreStarts...) - preStops = append(preStops, defaultHook.PreStops...) - preTerminates = append(preTerminates, defaultHook.PreTerminates...) + defaultVal := reflect.ValueOf(defaultHook) + for i := 0; i < hooksType.NumField(); i++ { + if strings.HasPrefix(hooksType.Field(i).Name, "Pre") { + field := hooksVal.Field(i) + field.Set(reflect.AppendSlice(field, defaultVal.Field(i))) + } + } } - // append the user-defined hooks after the default pre-hooks - // and because the post hooks are still empty, the user-defined post-hooks - // will be the first ones to be executed + // Append the user-defined hooks after the default pre-hooks + // and because the post hooks are still empty, the user-defined + // post-hooks will be the first ones to be executed. for _, userDefinedHook := range userDefinedHooks { - preCreates = append(preCreates, userDefinedHook.PreCreates...) - postCreates = append(postCreates, userDefinedHook.PostCreates...) - preStarts = append(preStarts, userDefinedHook.PreStarts...) - postStarts = append(postStarts, userDefinedHook.PostStarts...) - postReadies = append(postReadies, userDefinedHook.PostReadies...) - preStops = append(preStops, userDefinedHook.PreStops...) - postStops = append(postStops, userDefinedHook.PostStops...) - preTerminates = append(preTerminates, userDefinedHook.PreTerminates...) - postTerminates = append(postTerminates, userDefinedHook.PostTerminates...) + userVal := reflect.ValueOf(userDefinedHook) + for i := 0; i < hooksType.NumField(); i++ { + field := hooksVal.Field(i) + field.Set(reflect.AppendSlice(field, userVal.Field(i))) + } } - // finally, append the default post-hooks + // Finally, append the default post-hooks. for _, defaultHook := range defaultHooks { - postCreates = append(postCreates, defaultHook.PostCreates...) - postStarts = append(postStarts, defaultHook.PostStarts...) - postReadies = append(postReadies, defaultHook.PostReadies...) - postStops = append(postStops, defaultHook.PostStops...) - postTerminates = append(postTerminates, defaultHook.PostTerminates...) + defaultVal := reflect.ValueOf(defaultHook) + for i := 0; i < hooksType.NumField(); i++ { + if strings.HasPrefix(hooksType.Field(i).Name, "Post") { + field := hooksVal.Field(i) + field.Set(reflect.AppendSlice(field, defaultVal.Field(i))) + } + } } - return ContainerLifecycleHooks{ - PreCreates: preCreates, - PostCreates: postCreates, - PreStarts: preStarts, - PostStarts: postStarts, - PostReadies: postReadies, - PreStops: preStops, - PostStops: postStops, - PreTerminates: preTerminates, - PostTerminates: postTerminates, - } + return hooks } func mergePortBindings(configPortMap, exposedPortMap nat.PortMap, exposedPorts []string) nat.PortMap { diff --git a/lifecycle_test.go b/lifecycle_test.go index 0faefa8769..91102ccf82 100644 --- a/lifecycle_test.go +++ b/lifecycle_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "reflect" "strings" "testing" "time" @@ -971,3 +972,105 @@ func lifecycleHooksIsHonouredFn(t *testing.T, prints []string) { require.Equal(t, expects, prints) } + +func Test_combineContainerHooks(t *testing.T) { + var funcID string + defaultContainerRequestHook := func(ctx context.Context, req ContainerRequest) error { + funcID = "defaultContainerRequestHook" + return nil + } + userContainerRequestHook := func(ctx context.Context, req ContainerRequest) error { + funcID = "userContainerRequestHook" + return nil + } + defaultContainerHook := func(ctx context.Context, container Container) error { + funcID = "defaultContainerHook" + return nil + } + userContainerHook := func(ctx context.Context, container Container) error { + funcID = "userContainerHook" + return nil + } + + defaultHooks := []ContainerLifecycleHooks{ + { + PreBuilds: []ContainerRequestHook{defaultContainerRequestHook}, + PostBuilds: []ContainerRequestHook{defaultContainerRequestHook}, + PreCreates: []ContainerRequestHook{defaultContainerRequestHook}, + PostCreates: []ContainerHook{defaultContainerHook}, + PreStarts: []ContainerHook{defaultContainerHook}, + PostStarts: []ContainerHook{defaultContainerHook}, + PostReadies: []ContainerHook{defaultContainerHook}, + PreStops: []ContainerHook{defaultContainerHook}, + PostStops: []ContainerHook{defaultContainerHook}, + PreTerminates: []ContainerHook{defaultContainerHook}, + PostTerminates: []ContainerHook{defaultContainerHook}, + }, + } + userDefinedHooks := []ContainerLifecycleHooks{ + { + PreBuilds: []ContainerRequestHook{userContainerRequestHook}, + PostBuilds: []ContainerRequestHook{userContainerRequestHook}, + PreCreates: []ContainerRequestHook{userContainerRequestHook}, + PostCreates: []ContainerHook{userContainerHook}, + PreStarts: []ContainerHook{userContainerHook}, + PostStarts: []ContainerHook{userContainerHook}, + PostReadies: []ContainerHook{userContainerHook}, + PreStops: []ContainerHook{userContainerHook}, + PostStops: []ContainerHook{userContainerHook}, + PreTerminates: []ContainerHook{userContainerHook}, + PostTerminates: []ContainerHook{userContainerHook}, + }, + } + expects := ContainerLifecycleHooks{ + PreBuilds: []ContainerRequestHook{defaultContainerRequestHook, userContainerRequestHook}, + PostBuilds: []ContainerRequestHook{userContainerRequestHook, defaultContainerRequestHook}, + PreCreates: []ContainerRequestHook{defaultContainerRequestHook, userContainerRequestHook}, + PostCreates: []ContainerHook{userContainerHook, defaultContainerHook}, + PreStarts: []ContainerHook{defaultContainerHook, userContainerHook}, + PostStarts: []ContainerHook{userContainerHook, defaultContainerHook}, + PostReadies: []ContainerHook{userContainerHook, defaultContainerHook}, + PreStops: []ContainerHook{defaultContainerHook, userContainerHook}, + PostStops: []ContainerHook{userContainerHook, defaultContainerHook}, + PreTerminates: []ContainerHook{defaultContainerHook, userContainerHook}, + PostTerminates: []ContainerHook{userContainerHook, defaultContainerHook}, + } + + ctx := context.Background() + ctxVal := reflect.ValueOf(ctx) + var req ContainerRequest + reqVal := reflect.ValueOf(req) + container := &DockerContainer{} + containerVal := reflect.ValueOf(container) + + got := combineContainerHooks(defaultHooks, userDefinedHooks) + + // Compare for equal. This can't be done with deep equals as functions + // are not comparable so we us the unique value stored in funcID when + // the function is called to determine if they are the same. + gotVal := reflect.ValueOf(got) + gotType := reflect.TypeOf(got) + expectedVal := reflect.ValueOf(expects) + for i := 0; i < gotVal.NumField(); i++ { + fieldName := gotType.Field(i).Name + gotField := gotVal.Field(i) + expectedField := expectedVal.Field(i) + require.Equalf(t, expectedField.Len(), 2, "field %q not setup len expected %d got %d", fieldName, 2, expectedField.Len()) //nolint:testifylint // False positive. + require.Equalf(t, expectedField.Len(), gotField.Len(), "field %q len expected %d got %d", fieldName, gotField.Len(), expectedField.Len()) + for j := 0; j < gotField.Len(); j++ { + gotIndex := gotField.Index(j) + expectedIndex := expectedField.Index(j) + var gotID string + if gotIndex.Type().Name() == "ContainerRequestHook" { + gotIndex.Call([]reflect.Value{ctxVal, reqVal}) + gotID = funcID + expectedIndex.Call([]reflect.Value{ctxVal, reqVal}) + } else { + gotIndex.Call([]reflect.Value{ctxVal, containerVal}) + gotID = funcID + expectedIndex.Call([]reflect.Value{ctxVal, containerVal}) + } + require.Equalf(t, funcID, gotID, "field %q[%d] func expected %s got %s", fieldName, j, funcID, gotID) + } + } +}